len[i]:節點i的迴文串的長度
next[i][c]:節點i的迴文串在兩邊添加字符c以後變成的迴文串的編號
fail[i]:指向i的最長迴文後綴且不爲i
cnt[i]:節點i表示的迴文串在S中出現的次數(建樹時求出的不是完全的,count()加上子節點以後纔是正確的)
num[i]:以節點i迴文串的末尾字符結尾的但不包含本條路徑上的迴文串的數目。(也就是fail指針路徑的深度)
last:指向以字符串中上一個位置結尾的迴文串
cur: 指向由next構成的樹中當前迴文串的父親節點(即當前迴文串是cur左右兩邊各拓展一個字符得來)
s[i]:第i次添加的字符
p:添加的節點個數
n:添加的字符個數
dfs時,偶節點爲0 ,奇節點爲1,每條連邊由next樹得到,爲nxt[x][c] ,每個節點所代表的爲以該條邊結尾的迴文串。
模板:
#include<bits/stdc++.h>//空間O(N*字符集大小) 時間O(N*log(字符集大小))
#define ll long long
#define maxn 300010 //節點個數
using namespace std;
char a[maxn];
int vis[50];
struct pam
{
int len[maxn]; //len[i] 節點i的迴文串長度
int nxt[maxn][26]; //nxt[i]['c'] 節點i的迴文串在兩邊添加字符c後變成的迴文串編號
int fail[maxn]; //fail[i] 指向i的最長迴文後綴所在的節點 且不爲i
int cnt[maxn]; //cnt[i] 節點i表示的迴文串在S中出現的次數
int s[maxn]; //s[i] 第i次添加的字符
int num[maxn];
int last; //指向以字符串中上一個位置結尾的迴文串的節點
int cur; //指向由next構成的樹中當前迴文串的父親節點(即當前迴文串是cur左右兩邊各拓展一個字符得來)
int p; // 添加的節點個數
int n; // 添加的字符串個數
int newnode(int l) //新建節點
{
for(int i=0;i<=25;i++)nxt[p][i]=0; // 消除子節點
cnt[p]=num[p]=0; //節點p爲新迴文串所以出現次數爲0
len[p]=l;
return p++;
}
inline void init()
{
p=n=last=0;
newnode(0); //偶節點
newnode(-1); //奇節點
s[0]=-1;
fail[0]=1;
}
int get_fail(int x) //找到可以插入的節點
{
while(s[n-1-len[x]]!=s[n])x=fail[x];
return x;
}
inline void add(int c)
{
c-='a';
s[++n]=c;
int cur=get_fail(last); // 找到可以插入的節點並當做父節點
if(!nxt[cur][c])
{
int now=newnode(len[cur]+2);
fail[now]=nxt[get_fail(fail[cur])][c]; //從父節點的迴文後綴開始找,找到一個s[l-1]=s[n]的點則出邊的點爲最長迴文後綴
nxt[cur][c]=now;
num[now]=num[fail[now]]+1;
}
last=nxt[cur][c]; //成爲新的上一個位置
cnt[last]++;
}
void Count() //統計本質相同的迴文串的出現次數。與位置無關
{
for(int i=p-1;i>=0;i--)
{
cnt[fail[i]]+=cnt[i]; //每個節點會比父節點先算完,於是父節點能加到所有的子節點
}
}
}pam;
ll ans=0;
void dfs(int x,int step)
{
//printf("%d %d\n",x,step);
if(x>1)
{
ans+=1ll*pam.cnt[x]*step;
}
for(int i=0;i<='z'-'a';i++)
{
int v=pam.nxt[x][i];
if(!v)continue;
else
{
if(vis[i]==0)
{
vis[i]++;
dfs(v,step+1);
vis[i]--;
}
else
{
vis[i]++;
dfs(v,step);
vis[i]--;
}
}
}
}
int main()
{
scanf("%s",a);
pam.init();
int len=strlen(a);
for(int i=0;i<len;i++)
{
pam.add(a[i]);
}
pam.Count();
memset(vis,0,sizeof(vis));
dfs(0,0);
memset(vis,0,sizeof(vis));
dfs(1,0);
printf("%lld\n",ans);
return 0;
}