HDU4918 Query on the subtree 點分治+樹狀數組

bobo has a tree, whose vertices are conveniently labeled by 1,2,…,n. At the very begining, the i-th vertex is assigned with weight w i.

There are q operations. Each operations are of the following 2 types:

Change the weight of vertex v into x (denoted as "! v x"),
Ask the total weight of vertices whose distance are no more than d away from vertex v (denoted as "? v d").

Note that the distance between vertex u and v is the number of edges on the shortest path between them.

InputThe input consists of several tests. For each tests:

The first line contains n,q (1≤n,q≤10 5). The second line contains n integers w 1,w 2,…,w n (0≤w i≤10 4). Each of the following (n - 1) lines contain 2 integers a i,b i denoting an edge between vertices a i and b i (1≤a i,b i≤n). Each of the following q lines contain the operations (1≤v≤n,0≤x≤10 4,0≤d≤n).
OutputFor each tests:

For each queries, a single number denotes the total weight.Sample Input

4 3
1 1 1 1
1 2
2 3
3 4
? 2 1
! 1 0
? 2 1
3 3
1 2 3
1 2
1 3
? 1 0
? 1 1
? 1 2

Sample Output

3
2
1
6
6

題意:給你一棵樹,N個點,每個點一個權值,然後Q組操作(共兩種),第一種是求導一個節點距離不超過d的所有點的權值和是多少;第二種操作時,修改一個點的權值;
題解:多次進行操作1。此時不能每次都O(NlogN)了,太慢了。我們考慮到對於點分治,樹的重心一共有logN層,第一層爲整棵樹的重心,第二層爲第一層重心的子樹的重心,以此類推,每次至少分成兩個大小差不多的子樹,所以
一共有logN層。而且,對於一個點,他最多隻屬於logN個子樹,也就是最多隻屬於logN個重心。所以我們可以預處理出每個點所屬於的重心以及到這些重心的距離,以每個重心建樹狀數組,每個點按照到重心的距離插入到樹狀數組中,
然後每次查詢到u距離不超過d的點的個數就通過樹狀數組求前綴和得到。假設一個重心x到u的距離爲dis,那麼便統計到重心x距離不超過d-dis的點的個數,這個過程我們稱之爲“借力”,本身能力有限,所以需要藉助x的影響力。因爲
如果這個重心被u借力了,那麼這個重心的子重心一定也被借力,由於相鄰被借力的兩個重心x、y所統計的點會有重複,所以我們需要去重。去重的話我們就通過對每個節點再開一個v對x的樹狀數組,這個樹狀數組的意義爲:重心x的子
樹v的重心爲y時,子樹v中每個點到x的距離爲下標建立的樹狀數組。因爲重心x與重心y交集的部分,重心x包括的部分重心y一定包括,所以統計的時候減去v對x的樹狀數組中距x不超過d-dis的點的個數即可。訪問u所屬與的所有重心,
挨個借力,同時去重,便能得到距離u不超過d的點的個數。因爲重心最多logN層,每個樹狀數組最多N個點,logN複雜度的統計,所以每次查詢複雜度O(logN*logN)。我們最多爲每個節點開2個樹狀數組,而且每一層所有樹狀數組的
大小相加不超過N,所以樹狀數組的佔用空間爲O(2NlogN)。
在上面的基礎上稍做擴充。預處理的時候插入樹狀數組的就是該點的權值,查詢依舊是統計前綴和。修改點權值的時候,便是和查詢一樣,在u距重心x距離d的位置在x的樹狀數組中修改u的權值,同時修改u屬於重心x的子樹v的v對x的樹
狀數組中相同位置的值。複雜度和查詢一樣爲O(logN*logN)。

參考代碼:
#include<bits/stdc++.h>
using namespace std;
#define pii pair<int,int>
#define mkp make_pair
#define lowbit(x) (x&-x)

typedef long long ll;
const int INF=0x3f3f3f3f;
const int maxn=1e5+10;
int n,q,w[maxn];
char op[2];
struct MSG{
    int id1,id2;
    int dep;
} msg[maxn][17];
struct Edge{
    int v,nxt;
} edge[maxn<<1];
int vis[maxn],head[maxn],tot;
int root,siz[maxn],mx[maxn],fa[maxn],S;
int maxfloor[maxn],idn,L[maxn<<1],R[maxn<<1];
int c[maxn*17];
void Init()
{
    S=n;idn=0;
    tot=root=0;
    memset(c,0,sizeof c);
    memset(vis,0,sizeof vis);
    memset(w,0,sizeof w);
    memset(head,-1,sizeof head);
    memset(msg,0,sizeof msg);
}
void AddEdge(int x,int y)
{
    edge[tot].v=y;
    edge[tot].nxt=head[x];
    head[x]=tot++;
}

void Add(int l,int r,int pos,int val)
{
    while(pos<r-l)
        c[l+pos]+=val,pos+=lowbit(pos);
}
int Sum(int l,int r,int len)
{
    if(len<1) return 0;
    if(len>r-l-1) len=r-l-1;
    int res=0;
    while(len) res+=c[l+len],len-=lowbit(len);
    return res;
}

void getroot(int u,int fa)
{
    siz[u]=1;mx[u]=0;
    for(int i=head[u];~i;i=edge[i].nxt)
    {
        int v=edge[i].v;
        if(vis[v]||v==fa) continue;
        getroot(v,u);
        siz[u]+=siz[v];
        mx[u]=max(mx[u],siz[v]);
    }
    mx[u]=max(mx[u],S-siz[u]);
    if(mx[u]<mx[root]) root=u;
}

int getmaxdep(int u,int fa)
{
    int res=1;
    for(int i=head[u];~i;i=edge[i].nxt)
    {
        int v=edge[i].v;
        if(vis[v]||v==fa) continue;
        res=max(res,1+getmaxdep(v,u));
    }
    return res;
}
void dfs(int u,int fa,int deep,int id,int flor,int tp)
{
    if(!tp) msg[u][flor].id1=id;
    else msg[u][flor].id2=id;
    msg[u][flor].dep=deep;
    Add(L[idn],R[idn],deep,w[u]);
    for(int i=head[u];~i;i=edge[i].nxt)
    {
        int v=edge[i].v;
        if(!vis[v]&&v!=fa) dfs(v,u,deep+1,id,flor,tp);
    }
}
void solve(int u,int s,int flor)
{
    vis[u]=1; maxfloor[u]=flor;
    idn++; L[idn]=R[idn-1];
    R[idn]=L[idn]+getmaxdep(u,0)+1;
    dfs(u,0,1,idn,flor,0);
    msg[u][flor].id2=-1;
    for(int i=head[u];~i;i=edge[i].nxt)
    {
        int v=edge[i].v;
        if(vis[v]) continue;
        idn++;L[idn]=R[idn-1];
        R[idn]=L[idn]+getmaxdep(v,u)+2;
        dfs(v,u,2,idn,flor,1);
    }

     for(int i=head[u];~i;i=edge[i].nxt)
    {
        int v=edge[i].v;
        if(vis[v]) continue;
        S=siz[v]; root=0;
        if(siz[v]>siz[u]) S=s-siz[u];
        getroot(v,u);
        solve(root,siz[v],flor+1);
    }
}

int main()
{
    while(~scanf("%d%d",&n,&q))
    {
        Init();
        for(int i=1;i<=n;++i) scanf("%d",w+i);
        for(int i=1;i<n;++i)
        {
            int x,y;
            scanf("%d%d",&x,&y);
            AddEdge(x,y);AddEdge(y,x);
        }
        mx[root]=INF;
        getroot(1,0);
        solve(1,S,0);
        while(q--)
        {
            int x,y,ans;
            scanf("%s%d%d",&op,&x,&y);
            if(op[0]=='?')
            {
                ans=0;
                for(int f=0;f<=maxfloor[x];++f)
                {
                    int id1=msg[x][f].id1;
                    int id2=msg[x][f].id2;
                    int dep=msg[x][f].dep;
                    ans+=Sum(L[id1],R[id1],y+2-dep);
                    if(id2!=-1) ans-=Sum(L[id2],R[id2],y+2-dep);
                }
                printf("%d\n",ans);
            }
            else
            {
                for(int f=0;f<=maxfloor[x];++f)
                {
                    int id1=msg[x][f].id1;
                    int id2=msg[x][f].id2;
                    int dep=msg[x][f].dep;
                    Add(L[id1],R[id1],dep,y-w[x]);
                    if(id2!=-1) Add(L[id2],R[id2],dep,y-w[x]);
                }
                w[x]=y;
            }
        }
    }

    return 0;
}
View Code

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章