[DFT][Matrix-Tree定理][高斯消元] LOJ #6271. 生成樹求和 加強版

Solution

可以把（三進制下的）每一位分開來做。
這樣每一位都是一個長度爲 3 的循環卷積。
DFT 後各個點值就是獨立相乘了，於是可以對每個點值分別用矩陣樹定理求行列式，最後再 IDFT 回去。
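考慮三次單位根 $\omega=\omega_3$，滿足 $1+\omega+\omega^2=0$。由於 $10^9+7 \equiv 2 \pmod 3$，模 $10^9+7$ 下不存在三次單位根，所以把數寫成 $a+b\omega$ 的形式（類似複數）來運算。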
如何證明 $a=b=0$ 是 $a+b\omega=0$ 的充要條件呢？
充分性顯然，必要性考慮

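$$0=(a+b\omega)(a-b\omega)=a^2-b^2\omega^2=a^2+b^2\omega+b^2=a^2-ab+b^2$$
（其中用到了 $\omega^2=-1-\omega$ 和 $b\omega=-a$。）
把它看成關於 $a$ 的二次方程，判別式 $\Delta=-3b^2$（同理看成關於 $b$ 的方程有 $\Delta=-3a^2$）。而 $-3$ 在 $\mathbb{Z}/(10^9+7)$ 中不是二次剩餘，所以 $b\neq 0$ 時無解，只能 $b=0$，進而 $a=0$。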
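$$(a+b\omega)^{-1}=\frac{a+b\omega^2}{a^2-ab+b^2}=\frac{a-b-b\omega}{a^2-ab+b^2}$$
（因爲 $(a+b\omega)(a+b\omega^2)=a^2-ab+b^2$，由上面的論證，只要 $a+b\omega\neq 0$ 它就非零。）
兩個數相乘只要用 $\omega^2=-1-\omega$ 代掉就好了。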
#include <bits/stdc++.h>
#define show(x) cerr << #x << " = " << x << endl
using namespace std;

// Shorthand types used throughout the solution.
using ll = long long;
using Pairs = pair<int, int>;

// Array capacities: N bounds the vertex index (matrices are N x N),
// M bounds the edge index (edge arrays are 1-indexed up to m).
constexpr int N = 111;
constexpr int M = 10101;

// Buffered single-character reader over stdin.
// Returns the next byte as a non-negative value (0..255), or EOF once input
// is exhausted. The return type is int (not char) so that EOF stays distinct
// from a 0xFF byte and is not turned into 255 where plain char is unsigned.
inline int get(void) {
    static char buf[100000], *S = buf, *T = buf;
    if (S == T) {
        // Buffer drained: refill it with up to sizeof(buf) bytes.
        T = (S = buf) + fread(buf, 1, sizeof(buf), stdin);
        if (S == T) return EOF;
    }
    return static_cast<unsigned char>(*S++);
}
// Reads the next integer from the buffered stdin reader into x.
// Non-digit characters before the number are skipped; a '-' among them makes
// the result negative. If input ends before any digit is seen, x is left as 0
// and the function returns instead of spinning forever on EOF.
template<typename T>
inline void read(T &x) {
    x = 0;
    bool negative = false;
    int c = get();
    for (; c < '0' || c > '9'; c = get()) {
        if (c == EOF) return;
        if (c == '-') negative = true;
    }
    for (; c >= '0' && c <= '9'; c = get()) x = x * 10 + (c - '0');
    if (negative) x = -x;
}

// Prime modulus, and the modular inverses of 2 and 3
// (MOD + 1 is divisible by both, so these are exact).
constexpr int MOD = 1000000007;
constexpr int INV2 = (MOD + 1) / 2;
constexpr int INV3 = (MOD + 1) / 3;

// Canonical residue of x in [0, MOD), also for negative x.
inline int Mod(int x) {
    const int rem = x % MOD;  // in (-MOD, MOD): C++ division truncates toward zero
    if (rem < 0) return rem + MOD;
    return rem;
}

struct com {
    int r, i;
    com(int _r = 0, int _i = 0): r(Mod(_r)), i(Mod(_i)) {}
    inline com operator +(com b) {
        return com(r + b.r, i + b.i);
    }
    inline com &operator +=(com b) {
        *this = *this + b; return *this;
    }
    inline com operator -(com b) {
        return com(r - b.r, i - b.i);
    }
    inline com &operator -=(com b) {
        *this = *this - b; return *this;
    }
    inline com operator -(void) {
        return com(-r, -i);
    }
    inline com operator *(com b) {
        int _r = (ll)r * b.r % MOD - (ll)i * b.i % MOD;
        int _i = (ll)r * b.i % MOD + (ll)i * b.r % MOD - (ll)i * b.i % MOD;
        return com(_r, _i);
    }
    inline com operator *(int x) {
        return com((ll)r * x % MOD, (ll)i * x % MOD);
    }
    inline int zero(void) {
        return r == 0 && i == 0;
    }
} w[3];
// Square table of ring elements with room for N x N entries.
// m[row] yields a pointer to the start of that row, so m[row][col]
// reads or writes a single entry.
struct matrix {
    com a[N][N];
    com *operator [](int row) {
        return &a[row][0];
    }
} G[3];  // G[d]: the Laplacian at DFT point d, rebuilt for each ternary digit in main
// Graph size, and the running answer (accumulated modulo MOD).
int n, m;
int ans;
// Edge e (1-indexed) joins x[e] and y[e] and carries weight z[e].
int x[M], y[M], z[M];
// ad[v][d] = DFT at point d of the indicator of ternary digit v, i.e. w^(v*d);
// filled in by pre().
com ad[3][3];

inline int pwr(int a, int b) {
    int c = 1;
    while (b) {
        if (b & 1) c = (ll)c * a % MOD;
        b >>= 1; a = (ll)a * a % MOD;
    }
    return c;
}
inline int inv(int x) {
    return pwr(x % MOD, MOD - 2);
}
inline com inv(com x) {
    int a = x.r, b = x.i;
    int down = inv(((ll)a * a % MOD - (ll)a * b % MOD + (ll)b * b % MOD) % MOD + MOD);
    return com((ll)(a - b + MOD) * down % MOD, (ll)(MOD - b) * down % MOD);
}
// Accumulates a into x modulo MOD, in place.
inline void add(int &x, int a) {
    x += a;
    x %= MOD;
}

// In-place length-3 DFT over Z_MOD[w]: a[j] <- sum_k a[k] * w^(j*k),
// where w[m] holds w^m. Expects a to point at (at least) three elements.
inline void dft(com *a) {
    // Only three outputs are needed. A local buffer (instead of a static array
    // sized by the vertex capacity N) keeps this small and reentrant.
    com b[3];
    for (int j = 0; j < 3; j++)
        for (int k = 0; k < 3; k++)
            b[j] += a[k] * w[j * k % 3];
    for (int j = 0; j < 3; j++) a[j] = b[j];
}
// In-place inverse length-3 DFT: a[j] <- (1/3) * sum_k a[k] * w^(-j*k).
// Since w^3 = 1, w^(-m) = w^(2m mod 3). Expects three elements at a.
inline void idft(com *a) {
    // Local 3-element buffer instead of a static array sized by N.
    com b[3];
    for (int j = 0; j < 3; j++)
        for (int k = 0; k < 3; k++)
            b[j] += a[k] * w[2 * j * k % 3];
    for (int j = 0; j < 3; j++) a[j] = b[j] * INV3;
}
inline void pre(void) {
    w[0] = 1;
    w[1] = w[0] * com(0, 1);
    w[2] = w[1] * com(0, 1);
    com ppp = inv(com(-2, -1));
    ppp = ppp * com(-2, -1);
    for (int i = 0; i < 3; i++) {
        for (int j = 0; j < 3; j++)
            ad[i][j] = (i == j);
        dft(ad[i]);
    }
}
// Determinant of the n x n submatrix of `a` on rows/columns 1..n, computed by
// Gaussian elimination over Z_MOD[w]. `a` is taken by value and used as
// scratch space. Every nonzero element of this ring is invertible for this
// modulus, so any nonzero entry can serve as a pivot. n == 0 yields 1.
inline com det(matrix a, int n) {
    bool negate = false;  // each row swap flips the sign of the determinant
    for (int i = 1; i <= n; i++) {
        int pivot = 0;
        for (int j = i; j <= n; j++)
            if (!a[j][i].zero()) {
                pivot = j;
                break;
            }
        if (pivot == 0) return 0;  // no pivot in this column: singular
        if (pivot != i) {
            // Columns < i are already zero in both rows, so swap from i onwards.
            for (int j = i; j <= n; j++)
                swap(a[i][j], a[pivot][j]);
            negate = !negate;
        }
        // The pivot inverse is loop-invariant: one exponentiation per column,
        // not one per eliminated row.
        com pivot_inv = inv(a[i][i]);
        for (int j = i + 1; j <= n; j++) {
            if (a[j][i].zero()) continue;  // row already eliminated in this column
            com factor = pivot_inv * a[j][i];
            for (int col = i; col <= n; col++)
                a[j][col] -= factor * a[i][col];
        }
    }
    com res = 1;
    for (int i = 1; i <= n; i++)
        res = res * a[i][i];
    if (negate) res = -res;
    return res;
}

int main(void) {
    freopen("sum.in", "r", stdin);
    freopen("sum.out", "w", stdout);
    read(n); read(m);
    for (int i = 1; i <= m; i++) {
        read(x[i]); read(y[i]); read(z[i]);
    }
    pre();
    for (int k = 0, pw3 = 1; pw3 < 10000; k++, pw3 *= 3) {
        for (int i = 1; i <= n; i++)
            for (int j = 1; j <= n; j++)
                for (int d = 0; d < 3; d++)
                    G[d][i][j] = 0;
        for (int i = 1; i <= m; i++)
            for (int d = 0; d < 3; d++) {
                G[d][x[i]][x[i]] += ad[z[i] % 3][d];
                G[d][y[i]][y[i]] += ad[z[i] % 3][d];
                G[d][x[i]][y[i]] -= ad[z[i] % 3][d];
                G[d][y[i]][x[i]] -= ad[z[i] % 3][d];
            }
        com res[3];
        for (int d = 0; d < 3; d++)
            res[d] = det(G[d], n - 1);
        idft(res);
        add(ans, 1ll * res[1].r * pw3 % MOD);
        add(ans, 2ll * res[2].r * pw3 % MOD);
        for (int i = 1; i <= m; i++)
            z[i] /= 3;
    }
    cout << ans << endl;
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章