題意:給定一棵樹,樹的每個點有點權,定義2個點u和v之間的距離爲u到v的路徑上的點的點權的異或和。求全體點對(u,v):1<=u<=v<=n的距離和。
分析:考慮按位處理距離和。設有ans[i]個點對的距離的第i位爲1,則距離和=ans[0]*2^0+ans[1]*2^1+...+ans[20]*2^20。從而問題轉化爲點權爲0或1的情況。對於轉化後的問題,我是用樹分治處理的:對於子樹u的所有點對路徑,要麼經過點u,要麼不經過點u,經過點u的通過維護點權異或和爲0,1的鏈數來進行統計,不經過點u的遞歸處理。注意應該把點權轉化爲向量來處理,這樣只用跑一遍樹分治,跑20遍會tle。
看完題解後發現不用樹分治,直接樹形dp就可以了。這是因爲每個點只要訪問一次就能得到我們需要的信息,所以不需要每層都把所以點都遍歷一遍。
代碼(樹分治)
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
typedef long long LL;
const int maxn=1e5+10,maxl=21;
int n,a[maxn],b[maxn][maxl];
vector<int> G[maxn];
int sz[maxn],root,sum,minmaxs;
LL s[maxl][2],t[maxl][2],ans[maxl];
bool done[maxn];
void dfs_sz(int u,int fu)
{
sz[u]=1;
for (int i=0;i<G[u].size();i++)
{
int v=G[u][i];
if (v==fu||done[v]) continue;
dfs_sz(v,u);sz[u]+=sz[v];
}
}
void dfs_rt(int u,int fu)
{
int maxs=sum-sz[u];
for (int i=0;i<G[u].size();i++)
{
int v=G[u][i];
if (v==fu||done[v]) continue;
dfs_rt(v,u);maxs=max(maxs,sz[v]);
}
if (maxs<minmaxs) {minmaxs=maxs;root=u;}
}
void dfs(int u,int fu,int *o)
{
for (int l=0;l<maxl;l++) t[l][o[l]]++;
for (int i=0;i<G[u].size();i++)
{
int v=G[u][i];
if (v==fu||done[v]) continue;
int o1[maxl];
for (int l=0;l<maxl;l++) o1[l]=o[l]^b[v][l];
dfs(v,u,o1);
}
}
void solve(int u)
{
dfs_sz(u,-1);
minmaxs=maxn;sum=sz[u];
dfs_rt(u,-1);
u=root;done[u]=1;
//cout<<u<<endl;
memset(s,0,sizeof(s));
for (int i=0;i<G[u].size();i++)
{
int v=G[u][i];
if (done[v]) continue;
memset(t,0,sizeof(t));
dfs(v,u,b[v]);
for (int l=0;l<maxl;l++)
{
if (b[u][l])
{
ans[l]+=t[l][0];
ans[l]+=t[l][0]*s[l][0]+t[l][1]*s[l][1];
}
else
{
ans[l]+=t[l][1];
ans[l]+=t[l][0]*s[l][1]+t[l][1]*s[l][0];
}
s[l][0]+=t[l][0];s[l][1]+=t[l][1];
}
}
for (int i=0;i<G[u].size();i++)
{
int v=G[u][i];
if (done[v]) continue;
solve(v);
}
}
int main()
{
LL ret=0;
cin>>n;
for (int i=1;i<=n;i++) scanf("%d",&a[i]),ret+=a[i];
for (int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
G[u].push_back(v);G[v].push_back(u);
}
for (int i=1;i<=n;i++)
for (int j=0;j<maxl;j++)
b[i][j]=(a[i]>>j)&1;
solve(1);
for (int l=0;l<maxl;l++) ret+=(1<<l)*ans[l];
cout<<ret;
return 0;
}
代碼(樹形dp)
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn=1e5+10,maxl=21;
int n,a[maxn],f[maxn][maxl][2];
vector<int> G[maxn];
LL ans[maxl];
void dp(int u,int fu)
{
LL s[maxl][2];memset(s,0,sizeof(s));
for (int i=0;i<G[u].size();i++)
{
int v=G[u][i];
if (v==fu) continue;
dp(v,u);
for (int l=0;l<maxl;l++)
{
if (a[u]&(1<<l))
ans[l]+=f[v][l][0]+f[v][l][0]*s[l][0]+f[v][l][1]*s[l][1];
else
ans[l]+=f[v][l][1]+f[v][l][0]*s[l][1]+f[v][l][1]*s[l][0];
//if (l==10&&u==2) cout<<f[v][l][1]<<" "<<ans[l]<<endl;
s[l][0]+=f[v][l][0];
s[l][1]+=f[v][l][1];
}
}
for (int l=0;l<maxl;l++)
{
int d=(a[u]>>l)&1;
f[u][l][0]=s[l][0^d];f[u][l][1]=s[l][1^d];f[u][l][d]++;
//if (l==10&&u==3) cout<<d<<endl;
}
}
int main()
{
LL ret=0;
cin>>n;
for (int i=1;i<=n;i++) scanf("%d",&a[i]),ret+=a[i];
for (int i=1;i<n;i++)
{
int u,v;scanf("%d%d",&u,&v);
G[u].push_back(v);G[v].push_back(u);
}
dp(1,-1);
//for (int l=0;l<maxl;l++) cout<<ans[l]<<" ";
for (int l=0;l<maxl;l++) ret+=ans[l]*(1<<l);
cout<<ret;
return 0;
}