ntt模板

用法很像fft模板,對照着看,雖然我不是很懂他的意思,不過好像是處理fft的精度問題

hdu1402

大致題意:求a*b

測試案例:

input:

1

2

1000

2

output:

2

2000


解題思路:用ntt解

代碼:

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define MAXN 262144

const long long P=50000000001507329LL; // 190734863287 * 2 ^ 18 + 1
//const int P=1004535809; // 479 * 2 ^ 21 + 1
//const int P=998244353; // 119 * 2 ^ 23 + 1
const int G=3;

long long mul(long long x,long long y)
{
    return (x*y-(long long)(x/(long double)P*y+1e-3)*P+P)%P;
}
long long qpow(long long x,long long k,long long p)
{
    long long ret=1;
    while(k)
    {
        if(k&1) ret=mul(ret,x);
        k>>=1;
        x=mul(x,x);
    }
    return ret;
}

long long wn[25];
void getwn()
{
    for(int i=1; i<=18; ++i)
    {
        int t=1<<i;
        wn[i]=qpow(G,(P-1)/t,P);
    }
}

int len;
void NTT(long long y[],int op)
{
    for(int i=1,j=len>>1,k; i<len-1; ++i)
    {
        if(i<j) swap(y[i],y[j]);
        k=len>>1;
        while(j>=k)
        {
            j-=k;
            k>>=1;
        }
        if(j<k) j+=k;
    }
    int id=0;
    for(int h=2; h<=len; h<<=1)
    {
        ++id;
        for(int i=0; i<len; i+=h)
        {
            long long w=1;
            for(int j=i; j<i+(h>>1); ++j)
            {
                long long u=y[j],t=mul(y[j+h/2],w);
                y[j]=u+t;
                if(y[j]>=P) y[j]-=P;
                y[j+h/2]=u-t+P;
                if(y[j+h/2]>=P) y[j+h/2]-=P;
                w=mul(w,wn[id]);
            }
        }
    }
    if(op==-1)
    {
        for(int i=1; i<len/2; ++i) swap(y[i],y[len-i]);
        long long inv=qpow(len,P-2,P);
        for(int i=0; i<len; ++i) y[i]=mul(y[i],inv);
    }
}
void Convolution(long long A[],long long B[],int len1,int len2)
{
    int n=max(len1,len2);
    for(len=1; len<(n<<1); len<<=1);
    for(int i=len1; i<len; ++i)
    {
        A[i]=0;
    }
    for (int i=len2;i<len;i++)
    {
        B[i]=0;
    }

    NTT(A,1);
    NTT(B,1);
    for(int i=0; i<len; ++i)
    {
        A[i]=mul(A[i],B[i]);
    }
    NTT(A,-1);
}

long long A[MAXN],B[MAXN];
char s1[MAXN],s2[MAXN];
int main()
{
    getwn();
    int len1,len2;
    while(scanf("%s%s",s1,s2)==2)
    {
        len1=strlen(s1);
        len2=strlen(s2);
        long long ans=0;
        for(int i=0; i<len1; ++i)
        {
            A[i]=s1[len1-i-1]-'0';
        }
        for(int i=0; i<len2; ++i)
        {
            B[i]=s2[len2-i-1]-'0';
        }
        Convolution(A,B,len1,len2);//?aà?μ?2*nê?max(len1,len2)

        int temp;

        for(int i=0; i<len; ++i)
        {
            A[i+1]+=A[i]/10;
            A[i]%=10;
        }
        len=len1+len2-1;
        while (A[len]<=0&&len>0)
            len--;
        for (int i=len;i>=0;i--)
        {
            printf("%I64d",A[i]);
        }
        printf("\n");
    }
    return 0;
}




發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章