【CF434E】Furukawa Nagisa's Tree 點分治

【CF434E】Furukawa Nagisa's Tree

題意:一棵n個點的樹,點有點權。定義$G(a,b)$表示:我們將樹上從a走到b經過的點都拿出來,設這些點的點權分別爲$z_0,z_1...z_{l-1}$,則$G(a,b)=z_0+z_1k^1+z_2k^2+...+z_{l-1}k^{l-1}$。如果$G(a,b)=X \mod Y$(保證Y是質數),則我們稱(a,b)是好的,否則是壞的。現在想知道,有多少個三元組(a,b,c),滿足(a,b),(b,c),(a,c)都是好的或者都是壞的?

$n\le 10^5,Y\le 10^9$

題解:由於一個點對要麼是好的要麼是壞的,所以我們可以枚舉一下所有符合條件的3元組的情況。不過符合條件需要3條邊都相同,那我們可以反過來,統計不合法的3元組的情況(一共$2^3-2$種情況)。經過觀察我們發現,我們可以在 同時連接兩種顏色的邊 的那個點處統計貢獻,即把三元組的貢獻放到了點上。我們設$in_0(),in_1(i),out_0(i),out_1(i)$表示i有多少個好(壞)邊連入(出),則一個點對答案的貢獻就變成:

$2in_0(i)in_1(i)+2out_0(i)out_1(i)+in_0(i)out_1(i)+in_1(i)out_0(i)$

最後將答案/2即可。

所以現在我們只需要求:對於每個點,有多少好邊連入(連出)。這個用點分治可以搞定,因爲我們容易計算兩個多項式連接起來的結果。本題我採用的是容斥式的點分治。

 

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int maxn=100010;
typedef long long ll;
int n,cnt,tot,mn,rt;
ll X,Y,K,Ki,ans;
ll pw[maxn],pi[maxn],v[maxn],in1[maxn],in0[maxn],out1[maxn],out0[maxn];
int to[maxn<<1],nxt[maxn<<1],head[maxn],vis[maxn],siz[maxn];
struct node
{
	ll x;
	int y;
	node() {}
	node(ll a,int b) {x=a,y=b;}
	bool operator < (const node &a) const {return x<a.x;}
}p[maxn],q[maxn];
inline int rd()
{
	char gc=getchar();	int ret=0;
	while(gc<'0'||gc>'9')	gc=getchar();
	while(gc>='0'&&gc<='9')	ret=ret*10+gc-'0',gc=getchar();
	return ret;
}
inline void add(int a,int b)
{
	to[cnt]=b,nxt[cnt]=head[a],head[a]=cnt++;
}
inline ll pm(ll x,ll y)
{
	ll z=1;
	while(y)
	{
		if(y&1)	z=z*x%Y;
		x=x*x%Y,y>>=1;
	}
	return z;
}
void getrt(int x,int fa)
{
	int i,tmp=0;
	siz[x]=1;
	for(i=head[x];i!=-1;i=nxt[i])	if(!vis[to[i]]&&to[i]!=fa)	getrt(to[i],x),siz[x]+=siz[to[i]],tmp=max(tmp,siz[to[i]]);
	tmp=max(tmp,n-siz[x]);
	if(tmp<mn)	mn=tmp,rt=x;
}
void getp(int x,int fa,int dep,ll s1,ll s2)
{
	s1=(s1*K+v[x])%Y,s2=(s2+v[x]*((!dep)?0:pw[dep-1]))%Y,dep++;
	p[++tot]=node((X-s1+Y)*pi[dep]%Y,x),q[tot]=node(s2,x);
	for(int i=head[x];i!=-1;i=nxt[i])	if(!vis[to[i]]&&to[i]!=fa)
		getp(to[i],x,dep,s1,s2);
}
void calc(int x,int flag,int dep,ll s1,ll s2)
{
	int i,j,cnt;
	tot=0;
	s1=(s1*K+v[x])%Y,s2=(s2+v[x]*((!dep)?0:pw[dep-1]))%Y,dep++;
	p[++tot]=node((X-s1+Y)*pi[dep]%Y,x),q[tot]=node(s2,x);
	for(i=head[x];i!=-1;i=nxt[i])	if(!vis[to[i]])	getp(to[i],x,dep,s1,s2);
	sort(p+1,p+tot+1),sort(q+1,q+tot+1);
	for(cnt=0,i=j=1;i<=tot;i++)
	{
		for(;j<=tot&&q[j].x<=p[i].x;j++)
		{
			if(j==1||q[j].x!=q[j-1].x)	cnt=0;
			cnt++;
		}
		if(j!=1&&q[j-1].x==p[i].x)	out1[p[i].y]+=cnt*flag;
	}
	for(cnt=0,i=j=1;i<=tot;i++)
	{
		for(;j<=tot&&p[j].x<=q[i].x;j++)
		{
			if(j==1||p[j].x!=p[j-1].x)	cnt=0;
			cnt++;
		}
		if(j!=1&&p[j-1].x==q[i].x)	in1[q[i].y]+=cnt*flag;
	}
}
void dfs(int x)
{
	vis[x]=1;
	int i;
	calc(x,1,0,0,0);
	for(i=head[x];i!=-1;i=nxt[i])	if(!vis[to[i]])
	{
		calc(to[i],-1,1,v[x],0);
		tot=siz[to[i]],mn=1<<30,getrt(to[i],x),dfs(rt);
	}
}
int main()
{
	//freopen("cf434E.in","r",stdin);
	
	n=rd(),Y=rd(),K=rd(),X=rd(),Ki=pm(K,Y-2);
	int i,a,b;
	memset(head,-1,sizeof(head));
	for(i=1;i<=n;i++)	v[i]=rd();
	for(i=pw[0]=pi[0]=1;i<=n;i++)	pw[i]=pw[i-1]*K%Y,pi[i]=pi[i-1]*Ki%Y;
	for(i=1;i<n;i++)	a=rd(),b=rd(),add(a,b),add(b,a);
	tot=n,mn=1<<30,getrt(1,0),dfs(rt);
	for(i=1;i<=n;i++)
	{
		in0[i]=n-in1[i],out0[i]=n-out1[i];
		ans+=2*in1[i]*in0[i]+2*out1[i]*out0[i]+in0[i]*out1[i]+in1[i]*out0[i];
	}
	printf("%lld",1ll*n*n*n-ans/2);
	return 0;
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章