[Comet OJ - Contest #6 E]字符串

Description

給出一個長度爲n的字符串S,定義f(S)爲S的所有n(n+1)/2n*(n+1)/2個子串,兩兩求LCP的和
對於每個i,求出f(S[i…n]),答案對998244353取模
n<=200000

Solution

log^2的做法有很多這裏就不一一說了
數據結構學傻了.jpg
先考慮兩個後綴l和r的所有前綴互相匹配的答案,顯然只和後綴長度和LCP有關
設LCP爲x,後綴長度爲L和R,那麼答案爲G(l,r,x)=i=1Lj=1Rmin(i,j,x)G(l,r,x)=\sum_{i=1}^{L}\sum_{j=1}^{R}min(i,j,x)
稍微化一下式子,設A=2x33x2+x6,B=xx22,C=xA={2x^3-3x^2+x\over 6},B={x-x^2\over 2},C=x,那麼G(l,r,x)=A+B(L+R)+CLRG(l,r,x)=A+B(L+R)+CLR
考慮在後綴樹上,每次合併兩個後綴集合,那麼這兩個後綴集合的LCP就是一個常數,也就是A,B,C都是常數
由於要對每個後綴都求答案,只需要把每一對後綴的貢獻掛在較小的那個後綴上即可
顯然可以發現,對於一個後綴x的答案f(x)f(x)可以寫成f(x)=k(nx+1)+bf(x)=k(n-x+1)+b這樣的一次函數的形式,於是我們只需要對每個位置維護一次函數的係數即可
對於一個後綴L,如果插入了一個後綴R,那麼增量爲A+B(L+R)+CLR=L(B+CR)+A+BRA+B(L+R)+CLR=L(B+CR)+A+BR,只和後綴R的長度以及個數有關
考慮用線段樹合併解決,每個區間維護區間內後綴的長度和及個數,在合併的時候將用右區間的信息去貢獻左區間的一次函數即可做到O(n log n)

Code

#include <set>
#include <vector>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
#define pb(a) push_back(a)
using namespace std;

typedef long long ll;
typedef set<int> :: iterator it;

const int N=4e5+5,M=1e7+5,Mo=998244353;

int n,pos[N],p[N];
ll an[N];

namespace SAM{
	int son[N][26],len[N],pr[N],tot,lst;
	char s[N];

	int extend(int p,int x) {
		int np=++tot;len[np]=len[p]+1;
		for(;p&&!son[p][x];p=pr[p]) son[p][x]=np;
		if (!p) pr[np]=1;
		else {
			int q=son[p][x];
			if (len[q]==len[p]+1) pr[np]=q;
			else {
				int nq=++tot;
				fo(i,0,25) son[nq][i]=son[q][i];
				pr[nq]=pr[q];len[nq]=len[p]+1;
				pr[q]=pr[np]=nq;
				for(;p&&son[p][x]==q;p=pr[p]) son[p][x]=nq;
			}
		}
		return np;
	}

	void init() {
		scanf("%s",s+1);n=strlen(s+1);
		lst=tot=1;
		fd(i,n,1) pos[i]=lst=extend(lst,s[i]-'a');
	}
}

int cnt[M],ls[M],rs[M],rt[N],tot;
ll sum[M],tgk[M],tgb[M];

void insert(int &v,int l,int r,int x) {
	if (!v) v=++tot;
	cnt[v]++;sum[v]+=n-x+1;
	if (l==r) return;
	int mid=l+r>>1;
	if (x<=mid) insert(ls[v],l,mid,x);
	else insert(rs[v],mid+1,r,x);
}

void down(int v) {
	if (tgk[v]) {
		if (ls[v]) tgk[ls[v]]+=tgk[v];
		if (rs[v]) tgk[rs[v]]+=tgk[v];
		tgk[v]=0;
	}
	if (tgb[v]) {
		if (ls[v]) tgb[ls[v]]+=tgb[v];
		if (rs[v]) tgb[rs[v]]+=tgb[v];
		tgb[v]=0;
	}
}

void upd(int x,int y,ll L) {
	ll A=((L*L*L*2+L*L*3+L)/6-L*L)%Mo;
	ll B=((L*L+L)/2-L*L)%Mo,C=L;
	(tgk[x]+=B*cnt[y]+C*sum[y])%=Mo;
	(tgb[x]+=A*cnt[y]+B*sum[y])%=Mo;
} 

int merge(int x,int y,int L) {
	if (!x||!y) return x+y;
	down(x);down(y);
	upd(ls[x],rs[y],L);
	upd(ls[y],rs[x],L);
	ls[x]=merge(ls[x],ls[y],L);
	rs[x]=merge(rs[x],rs[y],L);
	cnt[x]=cnt[ls[x]]+cnt[rs[x]];
	sum[x]=sum[ls[x]]+sum[rs[x]];
	return x;
}

vector<int> son[N];

void dfs(int x) {
	if (p[x]) insert(rt[x],1,n,p[x]);
	for(int y:son[x]) {
		dfs(y);
		rt[x]=merge(rt[x],rt[y],SAM::len[x]);
	}
}

ll calc(ll l) {
	ll A=(l*l*l*2-l*l*3+l)/6%Mo;
	ll B=(-l*l+l)%Mo,C=l;
	return (A+B*l+C*l*l)%Mo;
}

void get_ans(int v,int l,int r) {
	if (l==r) {
		an[l]=(tgk[v]*(n-l+1)+tgb[v])%Mo;
		an[l]=(an[l]*2+calc(n-l+1))%Mo;
		return;
	}
	int mid=l+r>>1;down(v);
	get_ans(ls[v],l,mid);get_ans(rs[v],mid+1,r);
}

int main() {
	//freopen("e.in","r",stdin);
	//freopen("e.out","w",stdout);
	SAM::init();
	fo(i,2,SAM::tot) son[SAM::pr[i]].pb(i);
	fo(i,1,n) p[pos[i]]=i;
	dfs(1);get_ans(rt[1],1,n);
	fd(i,n,1) (an[i]+=an[i+1])%=Mo;
	fo(i,1,n) printf("%lld ",(an[i]+Mo)%Mo);
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章