【主席樹】【離線+樹狀數組】小崔的禮物

題意

有一棵樹,每次詢問u到v的路徑上在點權在【l,r】內的點權和

分析

好像正解是主席樹
將每個點O(logn)建在父節點上面(主席樹新技能get)
//如果建在前一個點上,記錄的是一顆子樹點權在【l,r】內的點權和
記錄到根的點權在【l,r】內的點權和
詢問時答案爲

o=Lca(u,v);
Query(rt[u])+Query(rt[v])-Query(rt[o])-Query(rt[fa[o]]);

時間O(n*logn)

然而有一種水法
離線+樹狀數組
樹狀數組維護一條鏈上的點權和

處理到點x時
先更新樹狀數組
把x的詢問都處理完
處理兒子
最後刪除更新
退出

比正解跑得快(主席樹常數大)

Q:爲什麼要樹狀數組?
A:因爲有【l,r】的區間詢問(而且因爲在樹上,不能排序,否則就用莫隊了)

主席樹

#include<cstdio>
#include<cmath>
#include<algorithm>
#define N 100010
using namespace std;
typedef long long LL;
struct Tree{
    LL v;
    int l,r;
}t[N*25];
struct Edge{
    int p,q,n;
}b[N*2];
struct Node{
    int v,id;
}c[N];
int a[N],n,m,ln,num,p,q,u,v,h[N],d[N],st[N][25],rt[N],o,top,ans;
bool cmp(Node p,Node q){
    return p.v<q.v;
}
void ljb(int p,int q){
    b[++num]=(Edge){p,q,h[p]};
    h[p]=num;
}
void ST(){
    for(int i=1;i<=ln;i++)
        for(int j=1;j<=n;j++)
            st[j][i]=st[st[j][i-1]][i-1];
}
int Lca(int p,int q){
    if(d[p]<d[q])swap(p,q);
    for(int i=ln;i>=0;i--)
        if(d[st[p][i]]>=d[q])p=st[p][i];
    if(p!=q){
        for(int i=ln;i>=0;i--){
            if(st[p][i]!=st[q][i]){
                p=st[p][i];
                q=st[q][i];
            }
            if(st[p][0]==st[q][0])break;
        }
        p=st[p][0];
    }
    return p;
}
void Ins(int& x,int y,int l,int r){
    x=++top;
    t[x]=t[y];
    t[x].v+=c[p].v;
    if(l==r)return;
    int mid=(l+r)>>1;
    if(p<=mid)Ins(t[x].l,t[y].l,l,mid);
    else Ins(t[x].r,t[y].r,mid+1,r);
}
LL Query(int x,int l,int r){
    if(p<=l&&r<=q)return t[x].v;
    int mid=(l+r)>>1;
    LL ans=0;
    if(p<=mid)ans+=Query(t[x].l,l,mid);
    if(q>mid)ans+=Query(t[x].r,mid+1,r);
    return ans;
}
void Dfs(int x){
    int y;
    d[x]=d[st[x][0]]+1;
    p=a[x];
    Ins(rt[x],rt[st[x][0]],1,n);
    for(int i=h[x];i;i=b[i].n){
        y=b[i].q;
        if(y==st[x][0])continue;
        st[y][0]=x;
        Dfs(y);
    }
}
int main(){
    freopen("data.txt","r",stdin);
    freopen("1.txt","w",stdout);
    scanf("%d%d",&n,&m);
    ln=log2(n);
    for(int i=1;i<=n;i++){
        scanf("%d",&c[i].v);
        c[i].id=i;
    }
    sort(c+1,c+n+1,cmp);
    for(int i=1;i<=n;i++)a[c[i].id]=i;
    for(int i=1;i<n;i++){
        scanf("%d%d",&p,&q);
        ljb(p,q);
        ljb(q,p);
    }
    Dfs(1);
    ST();
//  for(int i=1;i<=top;i++)printf("%d %d %d\n",i,t[i].v,t[i].l,t[i].r);
    for(int i=1;i<=m;i++){
        scanf("%d%d%d%d",&u,&v,&p,&q);
        p=lower_bound(c+1,c+n+1,(Node){p,0},cmp)-c;
        q=upper_bound(c+1,c+n+1,(Node){q,0},cmp)-c-1;
//      printf("%d %d %d\n",u,v,o);
//      printf("%d %d\n",p,q);
//      printf("%d %d %d %d\n",Query(rt[u],1,n),Query(rt[v],1,n),Query(rt[o],1,n),Query(rt[st[o][0]],1,n));
        if(p>q){
            printf("0 ");
            continue;
        }
        o=Lca(u,v);
        if(a[o]>=p&&a[o]<=q)ans=c[a[o]].v;
        else ans=0;
        printf("%lld ",Query(rt[u],1,n)+Query(rt[v],1,n)-2*Query(rt[o],1,n)+ans);
    }
}

離線+樹狀數組(Orz xyl)

#include <cstdio>
#include <algorithm>
const int maxn =1e5+10;
#define For(n) for (int i=1;i<=n;++i)
#define rep(i,n) for (int i=1;i<=n;++i)
#define repp(i,x,y) for (int i=x;i<=y;++i)
#define lowbit(x) ((x)&(-x))
#define LL long long
using namespace std;
int dep[maxn],tot,st[maxn<<2],top,n,m,val[maxn],x,y,p,q,head[maxn],g,anc[maxn][20];
int key[maxn<<2],h[maxn];
LL tr[maxn<<2],t1,t2;
struct prob{int x,y,l,r,cnt;LL ans;}a[maxn];
struct Edge{int v,next;}b[maxn<<1];
struct node{int num,x,l,r,next;}c[maxn<<2];

void add(int u,int v){b[++g]=(Edge){v,head[u]};head[u]=g;}

void add2(int k,int u,int l,int r){c[++g]=(node){k,u,l,r,h[u]};h[u]=g;}

/*################LCA##############*/
void dfs(int x)
{
    for (int i=head[x];i;i=b[i].next)
    {
        int v=b[i].v;
        if (anc[v][0]) continue;
        anc[v][0]=x;
        dep[v]=dep[x]+1;
        dfs(v);
    }
    return ;
}

void ST()
{
    repp(j,1,18)
      For(n)
        anc[i][j]=anc[anc[i][j-1]][j-1];
}

int lca(int x,int y)
{
    if (dep[x]<dep[y]) swap(x,y);
    for (int j=18;j>=0;--j)
        if (dep[anc[x][j]]>=dep[y]) x=anc[x][j];
    if (x==y) return x;
    for (int j=18;j>=0;--j)
        if (anc[x][j]^anc[y][j]) x=anc[x][j],y=anc[y][j];
    return anc[x][0];
}

/*################LCA##############*/


/*#################樹狀數組##############*/
void INS(int x,int y)
{
    while (x<=top)
    {
        tr[x]+=y;
        x+=lowbit(x);
    }
}

LL Sum(int x)
{
    LL tmp=0;
    while (x)
    {
        tmp+=tr[x];
        x-=lowbit(x);
    }
    return tmp;
}
/*#################樹狀數組##############*/

void dfs2(int x)
{
    INS(val[x],key[val[x]]);
    for (int i=h[x];i;i=c[i].next)
    {
        LL tmp=Sum(c[i].r)-Sum(c[i].l-1);
        int num=c[i].num;
        a[num].ans+=tmp*(LL)((++a[num].cnt)>2 ?1:-1);
    }
    for (int i=head[x];i;i=b[i].next)
      if (b[i].v^anc[x][0]) dfs2(b[i].v);
    INS(val[x],-key[val[x]]);
}

int main()
{

    scanf("%d%d",&n,&m);
    For(n) scanf("%d",val+i),st[++top]=val[i];
    add(n+1,1);val[n+1]=1;st[++top]=1;
    For(n-1) scanf("%d%d",&x,&y),add(x,y),add(y,x);
    dfs(n+1);ST();
    For(m) scanf("%d%d%d%d",&x,&y,&p,&q),a[i]=(prob){x,y,p,q,0,0},
           st[++top]=p,st[++top]=q;
    sort(st+1,st+1+top);top=unique(st+1,st+1+top)-st;--top;
    For(n) val[i]=lower_bound(st+1,st+1+top,val[i])-st;
    For(m) a[i].l=lower_bound(st+1,st+1+top,a[i].l)-st,
           a[i].r=lower_bound(st+1,st+1+top,a[i].r)-st;
    For(top) key[i]=st[i];
    g=0;
    For(m)
    {
        int L=a[i].l;
        int R=a[i].r;
        int LCA=lca(a[i].x,a[i].y);
        add2(i,a[i].x,L,R);add2(i,a[i].y,L,R);
        add2(i,LCA,L,R);   add2(i,anc[LCA][0],L,R);
    }
    LL hh=Sum(190087);
    hh=Sum(71050);
    dfs2(n+1);
    For(m) printf("%lld ",max(0ll,a[i].ans));
    return 0;
}

造數據

#include<cstdio>
#include<cstdlib>
#include<ctime>
#include<algorithm>
#define MOD 100000
using namespace std;
bool b[MOD+10];
int n,m,p,q;
int main(){
    freopen("data.txt","w",stdout);
    srand((unsigned)time(NULL));
    n=rand()%MOD+1;
    m=rand()%MOD+1;
    printf("%d %d\n",n,m);
    for(int i=1;i<=n;i++)printf("%d ",rand()%MOD+1);
    printf("\n");
    b[1]=1;
    for(int i=1;i<n;i++){
        p=rand()%n+1;
        q=rand()%n+1;
        while(b[p]==b[q])p=rand()%n+1,q=rand()%n+1;
        b[p]=b[q]=1;
        printf("%d %d\n",p,q);
    }
    for(int i=1;i<=m;i++){
        p=rand()%MOD+1;
        q=rand()%MOD+1;
        if(p>q)swap(p,q);
        printf("%d %d %d %d\n",rand()%n+1,rand()%n+1,p,q);
    }
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章