BZOJ 3992 [SDOI2015]序列統計 NTT

題意:
存在一個集合S,求長度爲N,每一個元素都是S中的元素(可重複),並且該序列所有數的乘積mod M = x 的序列個數。
M是質數,且集合中的所有元素的範圍都在[0,M-1]內。
並且x!=0
解析:
因爲有M是質數這個特殊條件,所以我們可以求出來M的原根G,之後因爲G的0~(phi(M)-1)可以完美替代0~M-1中的數,於是我們可以考慮把S中所有的數用G的幾次冪來代替。
至於爲什麼這樣考慮。
因爲這樣就把我們所需要的乘法轉化成了冪的加法。
搞出集合S的生成函數。
由於每個數可以選取多次,所以接下來的問題就是S的生成函數的n次冪的對應的第x次冪項。
我們發現過程其實就是多項式的乘積過程,並且題目要求答案mod 一個原根爲3的大質數,所以我們可以考慮用NTT來優化這一過程。
需要注意的是,在多項式乘積的時候,我們每一次要把大於m的係數加到其mod m後的那一項上,也就是說,不要直接消除,而是在乘積的時候把越界的部分轉到mod m下。
總複雜度O(lognmlogm)
代碼:

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define mod 1004535809
#define G 3
#define N 262145
using namespace std;
typedef long long ll;
int n,m,x,s,root;
ll prime[20010];
int pos[17010];
ll a[N],b[N];
int rev[N];
int num[17010];
int tot;
ll mm;
ll quick_my(ll x,ll y,ll MOD)
{
    ll ret=1;
    while(y)
    {
        if(y&1)ret=ret*x%MOD;
        x=x*x%MOD;
        y>>=1;
    }
    return ret;
}
void get_factor(ll x)
{
    tot=0;
    for(ll i=2;i*i<=x;i++)
    {
        if(x%i==0)
        {
            prime[++tot]=i;
            while(x%i==0)x/=i;
        }
    }
    if(x!=1)prime[++tot]=x;
}
bool check(ll x,ll MOD,ll PHI)
{
    for(int i=1;i<=tot;i++)
    {
        if(quick_my(x,PHI/prime[i],MOD)==1)return 0;
    }
    return 1;
}
int find_primitive_root(ll x)
{
    ll tmp=x-1;
    get_factor(tmp);
    for(int i=2;;i++)
    {
        if(check(i,x,tmp))return i;
    }
}
void NTT(ll *a,int f)
{
    for(int i=0;i<m;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
    for(int h=2;h<=m;h<<=1)
    {
        ll wn=quick_my(G,(mod-1)/h,mod);
        for(int i=0;i<m;i+=h)
        {
            ll w=1;
            for(int j=0;j<(h>>1);j++,w=w*wn%mod)
            {
                ll t=w*a[i+j+(h>>1)]%mod;
                a[i+j+(h>>1)]=((a[i+j]-t)%mod+mod)%mod;
                a[i+j]=(a[i+j]+t)%mod;
            }
        }
    }
    if(f==-1)
    {
        for(int i=1;i<(m>>1);i++)swap(a[i],a[m-i]);
        ll inv=quick_my(m,mod-2,mod);
        for(int i=0;i<m;i++)a[i]=a[i]*inv%mod;
    }
}
ll ret[N];
void get_my(int y)
{
    ret[0]=1;
    while(y)
    {
        NTT(b,1);
        if(y&1)
        {
            NTT(ret,1);
            for(int i=0;i<m;i++)ret[i]=ret[i]*b[i]%mod;
            NTT(ret,-1);
            for(int i=m-1;i>=mm-1;i--)ret[i-mm+1]=(ret[i-mm+1]+ret[i])%mod,ret[i]=0;
        }
        for(int i=0;i<m;i++)b[i]=b[i]*b[i]%mod;
        NTT(b,-1);
        for(int i=m-1;i>=mm-1;i--)b[i-mm+1]=(b[i-mm+1]+b[i])%mod,b[i]=0;
        y>>=1;
    }
}
int main()
{
    scanf("%d%d%d%d",&n,&m,&x,&s);
    for(int i=1;i<=s;i++)scanf("%d",&num[i]);
    root=find_primitive_root(m);
    ll tmp=1;
    for(int i=0;i<m-1;i++)
        pos[tmp]=i,tmp=tmp*root%m;
    int l=m*2,L=0;
    mm=m;
    for(m=1;m<=l;m<<=1)L++;
    for(int i=0;i<m;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
    for(int i=1;i<=s;i++)
        if(num[i]!=0)b[pos[num[i]]]++;
    get_my(n);
    printf("%lld\n",ret[pos[x]]);
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章