FFT,NTT模板

FFT

#include<bits/stdc++.h>
using namespace std;

typedef long long ll;                        // not used by the FFT program
const int maxn = 500005;                     // capacity of every global buffer below
const int inf = 0x3f3f3f3f;
const int mod = 1e9 + 7;
using cp = complex<double>;
const double PI = acos(-1);

char sa[maxn], sb[maxn];                     // the two decimal operands as read
int n = 1, lena, lenb, res[maxn];            // n: transform length; res: product digits, least significant first
cp a[maxn], b[maxn], omg[maxn], inv[maxn];   // operands, roots of unity and their conjugates

// Fills omg[k] = e^{2*pi*i*k/n} for k in [0, n) and inv[k] = conj(omg[k]);
// fft() uses omg for the forward transform and inv for the inverse one.
void init() {
    for (int k = 0; k < n; ++k) {
        const double angle = 2 * PI * k / n;
        omg[k] = cp(cos(angle), sin(angle));
        inv[k] = conj(omg[k]);
    }
}

// In-place iterative radix-2 FFT of a[0..n) using the root table `omg`
// (main passes the global omg for the forward transform and inv for the inverse;
// the inverse is left unscaled, so the caller divides by n).
// Reads the global transform length n, which main always makes a power of two.
void fft(cp *a, cp *omg) {
    // Number of index bits to reverse: the smallest `bits` with 2^bits >= n.
    int bits = 0;
    while ((1 << bits) < n) ++bits;

    // Bit-reversal permutation: bit b of i moves to bit (bits - 1 - b).
    for (int i = 0; i < n; ++i) {
        int rev = 0;
        for (int bit = 0; bit < bits; ++bit)
            rev = (rev << 1) | ((i >> bit) & 1);
        if (i < rev) swap(a[i], a[rev]);
    }

    // Butterfly passes over blocks of length 2, 4, ..., n.
    for (int len = 2; len <= n; len <<= 1) {
        const int half = len / 2;
        const int stride = n / len;  // omg[stride * k] is the k-th len-th root
        for (int start = 0; start < n; start += len) {
            for (int k = 0; k < half; ++k) {
                cp &lo = a[start + k];
                cp &hi = a[start + k + half];
                const cp t = omg[stride * k] * hi;
                hi = lo - t;
                lo += t;
            }
        }
    }
}


int main() {
    scanf("%d", &n);
    scanf("%s%s", sa, sb);
    lena = lenb = n;
    n = 1;
    while(n < lena + lenb) n <<= 1;
    for (int i = 0; i < lena; i ++) 
        a[i].real(sa[lena-1-i] - '0');
    for (int i = 0; i < lenb; i ++)
        b[i].real(sb[lenb-1-i] - '0');
    init();
    fft(a, omg);fft(b, omg); 
    for (int i = 0; i < n; i ++) 
        a[i] *= b[i];
    fft(a, inv);
    for (int i = 0; i < n; i ++) {
        res[i] += floor(a[i].real()/n + 0.5);
        res[i+1] += res[i] / 10;
        res[i] %= 10;
    }
    int pos = n - 1;
    while(!res[pos]) pos--;
    for (int i = pos; i >= 0; i --) printf("%d", res[i]);
    puts("");
    return 0;
}

NTT

#include<bits/stdc++.h>
using namespace std;

#define ll long long
const int maxn = 1e5 + 5;
const int inf = 0x3f3f3f3f;
const int mod = 998244353;
// Reduces a value in [0, 2*mod) into [0, mod). Evaluates x twice, so pass
// side-effect-free expressions only.
#define Mod(x) ((x)>=mod?(x)-mod:(x))
// Generator used by ntt() for the roots of unity. This is a typed constant rather
// than `#define g 3`, which would rewrite every later token spelled `g`.
const int g = 3;

int rnk[maxn];          // bit-reversal permutation, filled by init()
ll a[maxn], b[maxn];    // operand coefficient buffers (least significant digit first)


// Returns a^b modulo the global prime `mod` by binary exponentiation (b >= 0).
ll Ksm(ll a, ll b) {
    // Reduce the base first so that a * a cannot overflow 64 bits for large
    // bases, and so negative bases yield a result in [0, mod).
    a %= mod;
    if (a < 0) a += mod;
    ll res = 1;
    while(b) {
        if(b & 1) res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}

// Fills rnk[0..n) with the bit-reversal permutation used by ntt() for a
// transform of length n (main passes a power of two); lim = ceil(log2 n).
void init(int n) {
    memset(rnk, 0, sizeof(rnk));
    int lim = 0;
    while((1<<lim) < n) lim ++;
    // For n <= 1 (lim == 0) the permutation is the identity already stored by
    // memset; returning here avoids the undefined shift by lim - 1 == -1.
    if (lim == 0) return;
    for (int i = 0; i < n; i ++)
        rnk[i] = (rnk[i>>1]>>1) | ((i&1) << (lim-1));
}


// In-place iterative NTT of a[0..n) modulo `mod`.
// op == 1: forward transform with generator g.
// op == -1: inverse transform (inverse roots, then every entry scaled by n^{-1}).
// rnk[] must already hold the bit-reversal permutation for this n (see init()).
void ntt(ll *a, int op, int n) {
    for (int i = 0; i < n; i++) {
        if (i < rnk[i]) swap(a[i], a[rnk[i]]);
    }
    for (int len = 2; len <= n; len <<= 1) {
        const int half = len >> 1;
        int step = Ksm(g, (mod - 1) / len);
        if (op == -1) step = Ksm(step, mod - 2);
        for (int start = 0; start < n; start += len) {
            int w = 1;
            for (int k = 0; k < half; k++) {
                ll &lo = a[start + k];
                ll &hi = a[start + k + half];
                const int t = 1ll * hi * w % mod;
                hi = Mod(lo - t + mod);
                lo = Mod(lo + t);
                w = 1ll * w * step % mod;
            }
        }
    }
    if (op == -1) {
        const int inv_n = Ksm(n, mod - 2);
        for (int i = 0; i < n; i++)
            a[i] = 1ll * a[i] * inv_n % mod;
    }
}

char s1[maxn];  // first decimal operand as read, most significant digit first
char s2[maxn];  // second decimal operand as read
ll ans[maxn];   // digit buffer for main()'s base-10 carry loop

int main() {
    scanf("%s", s1);
    scanf("%s", s2);
    int len1 = strlen(s1), len2 =strlen(s2);
    int n = 1;
    while(n < len1 + len2) n <<= 1;
    init(n);
    for (int i = 0; i < len1; i ++) a[len1-i-1] = s1[i]-'0';
    for (int i = 0; i < len2; i ++) b[len2-i-1] = s2[i]-'0';
    ntt(a, 1, n); ntt(b, 1, n);
    for (int i = 0; i < n; i ++) 
        a[i] = (1ll * a[i] * b[i]) % mod;
    ntt(a, -1, n);
    for (int i = 0; i < n; i ++) 
        cout << a[i] << " ";
    cout << endl;
    for (int i = 0; i < n; i ++) {
        ans[i+1] += ans[i] / 10;
        ans[i] %= 10;
    }
    int pos = n-1;
    while(!a[pos]) pos--;
    for (int i = pos; i >= 0; i --) cout << a[i];
    cout << endl;
    return 0;
}

三模NTT

#include<bits/stdc++.h>

// No swap macro: an XOR-swap macro zeroes its operand when both arguments name
// the same object and rewrites every later `swap(` token. Unqualified swap()
// resolves to std::swap through `using namespace std` below.
#define ll long long
const ll maxn = 3 * 1e6 + 10;
using namespace std;

// The three NTT primes (generator g = 3) whose product bounds the exact
// coefficients reconstructed by CRT in ntt_Mul.
const ll P1 = 469762049, P2 = 998244353, P3 = 1004535809, g = 3;
const ll PP = 1ll * P1 * P2;  // modulus of the intermediate CRT value
// n: input degree, p: output modulus (both read by main).
// len / lim: transform length and its log2, set by init().
ll n, m, p, len = 1, lim;
// Polynomial and work buffers, maxn entries each; r holds the bit-reversal permutation.
ll a[maxn], b[maxn], tmp1[maxn], tmp2[maxn], ans[3][maxn], r[maxn];
ll res[maxn], tmp[maxn], base[maxn];


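/*
    傳的參數n,m都比實際個數少一 (n--; m--;)
    Length parameters (n, m, alen, blen) are passed as "number of coefficients - 1",
    i.e. as the polynomial degree:
      輸入兩個數 (two coefficients) -> n = 1
      輸入一個數 (one coefficient)  -> n = 0
 */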
// Overflow-safe modular multiplication: returns (a * b) mod `mod` in [0, mod).
// Needed for moduli up to P1 * P2 (~4.7e17), where the plain 64-bit product overflows.
// The quotient is estimated in long double. The remainder a*b - q*mod is then formed
// in unsigned (wrap-around) arithmetic, which is well defined, unlike the signed
// overflow of a direct a * b. The true remainder is small, so the wrapped value
// converts back exactly.
// NOTE(review): assumes a 64-bit-mantissa long double (GCC/Clang on x86); where
// long double is a plain double the estimate can be too coarse for moduli near 2^59.
long long Mul(long long a, long long b, long long mod) {
    a %= mod;
    if (a < 0) a += mod;
    b %= mod;
    if (b < 0) b += mod;
    const long long q = (long long)((long double)a * b / mod);
    long long r = (long long)((unsigned long long)a * (unsigned long long)b
                              - (unsigned long long)q * (unsigned long long)mod);
    // q may be off by one in either direction; fold r back into [0, mod).
    r %= mod;
    if (r < 0) r += mod;
    return r;
}

// Modular exponentiation: returns a^p mod `mod` by binary exponentiation (p >= 0).
// The base is reduced first, so large or negative bases cannot overflow a * a.
// Intermediate products stay below mod^2, so mod must be below about 3e9; every
// modulus this file passes (P1, P2, P3) satisfies that.
long long Ksm(long long a, long long p, long long mod) {
    a %= mod;
    if (a < 0) a += mod;
    long long result = 1;
    while (p) {
        if (p & 1) result = a * result % mod;
        a = a * a % mod;
        p >>= 1;
    }
    return result % mod;  // also makes mod == 1 yield 0 when p == 0
}

// Prepares a transform for a product of degree n (pass alen + blen):
// len becomes the smallest power of two strictly greater than n, so all n + 1
// coefficients fit; lim = log2(len); r[0..len) receives the bit-reversal
// permutation used by ntt_Mod. Requires len <= maxn.
void init(ll n) {
    len = 1; lim = 0;
    while (len <= n) len <<= 1, lim++;
    r[0] = 0;
    // Start at 1 and stop before len. For len == 1 (lim == 0) the loop is skipped,
    // avoiding a shift by -1, and r[len] (never read) is no longer written.
    for (ll i = 1; i < len; i++)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lim - 1));
}

// In-place iterative NTT of a[0..n) modulo the prime `mod`, with generator g.
// type == 1: forward transform. Any other type uses inverse roots, and
// type == -1 additionally scales every entry by n^{-1}.
// r[] must hold the bit-reversal permutation for this n (see init()).
void ntt_Mod(ll *a, const ll n, const ll type, const ll mod) {
    for (ll i = 0; i < n; i++)
        if (i < r[i]) swap(a[i], a[r[i]]);
    // The root depends only on the direction, so choose it once.
    const ll root = (type == 1) ? g : Ksm(g, mod - 2, mod);
    for (ll half = 1; half < n; half <<= 1) {
        const ll blockLen = half << 1;
        const ll step = Ksm(root, (mod - 1) / blockLen, mod);
        for (ll start = 0; start < n; start += blockLen) {
            ll w = 1;
            for (ll k = 0; k < half; k++) {
                ll &lo = a[start + k];
                ll &hi = a[start + k + half];
                const ll x = lo;
                const ll y = w * hi % mod;
                lo = (x + y) % mod;
                hi = (x - y + mod) % mod;
                w = w * step % mod;
            }
        }
    }
    if (type == -1) {
        const ll invN = Ksm(n, mod - 2, mod);
        for (ll i = 0; i < n; i++)
            a[i] = a[i] * invN % mod;
    }
}

// Debug helper: prints a[0..len] (inclusive), each followed by a space, then ends the line.
void Out(ll *a, ll len) {
    int idx = 0;
    while (idx <= len)
        cout << a[idx++] << " ";
    cout << endl;
}

// Multiplies polynomial a (coefficients a[0..alen]) by b (coefficients b[0..blen])
// modulo an arbitrary `mod`. The exact product is computed with three NTTs modulo
// P1, P2 and P3 and recombined by CRT. This is exact while every true product
// coefficient is below P1 * P2 * P3.
// The result (degree alen + blen) overwrites a[0..alen+blen]; the new degree is returned.
// Uses the global work buffers tmp1, tmp2 and ans.
int ntt_Mul(ll *a, ll *b, ll alen, ll blen, ll mod) {
    init(alen + blen);
    const ll primes[3] = {P1, P2, P3};
    for (int k = 0; k < 3; k++) {
        const ll pk = primes[k];
        // Copy only the real coefficients and zero-pad up to len, so stale entries
        // past alen / blen can never leak into the transform.
        for (ll i = 0; i < len; i++) {
            tmp1[i] = i <= alen ? a[i] % pk : 0;
            tmp2[i] = i <= blen ? b[i] % pk : 0;
        }
        ntt_Mod(tmp1, len, 1, pk);
        ntt_Mod(tmp2, len, 1, pk);
        for (ll i = 0; i < len; i++) ans[k][i] = tmp1[i] * tmp2[i] % pk;
        ntt_Mod(ans[k], len, -1, pk);
    }

    // The CRT inverses do not depend on i, so compute them once.
    const ll invP2modP1 = Ksm(P2 % P1, P1 - 2, P1);
    const ll invP1modP2 = Ksm(P1 % P2, P2 - 2, P2);
    const ll invPPmodP3 = Ksm(PP % P3, P3 - 2, P3);
    for (ll i = 0; i <= alen + blen; i++) {
        // t = the coefficient modulo P1 * P2, from the first two residues.
        ll t = (Mul(ans[0][i] * P2 % PP, invP2modP1, PP) +
                Mul(ans[1][i] * P1 % PP, invP1modP2, PP)) % PP;
        // The exact coefficient is t + K * PP, with K fixed by the residue mod P3.
        ll K = ((ans[2][i] - t) % P3 + P3) % P3 * invPPmodP3 % P3;
        // Mul keeps K * (PP mod mod) from overflowing when mod is large.
        a[i] = (t % mod + Mul(K, PP % mod, mod)) % mod;
    }
    return alen + blen;
}

// Raises polynomial a (coefficients a[0..blen]; entries past blen are expected to
// be zero) to the power b modulo `mod` by binary exponentiation. The result is
// written back into a and its degree is returned. Clobbers the global buffers base and tmp.
int ntt_Ksm(ll *a, ll b, int blen, ll mod) {
    memcpy(base, a, sizeof(base));
    // sizeof(*a) is the element size; sizeof(a) would be the size of a pointer.
    memset(a, 0, maxn * sizeof(*a));
    a[0] = 1 % mod;  // empty product; 1 % mod stays reduced even for mod == 1
    int alen = 0;
    while (b) {
        if (b & 1) alen = ntt_Mul(a, base, alen, blen, mod);
        b >>= 1;
        // Square only if another bit remains. The final squaring was never used,
        // and doubling blen there could push the transform length past maxn.
        if (b) {
            memcpy(tmp, base, sizeof(tmp));
            blen = ntt_Mul(base, tmp, blen, blen, mod);
        }
    }
    return alen;
}

int main() {
    scanf("%lld %lld", &n, &p);
    for(ll i = 0; i <= n; i++) scanf("%lld", &a[i]);
    int anslen = ntt_Ksm(a, 5, n, p);
    for (int i = 0; i <= anslen; i ++)
        cout << a[i] << " ";
    cout << endl;
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章