[NTT][DP][樹鏈剖分][分治] LOJ #6289. 花朵

Solution

$f_{u,0/1,i}$ 表示在 $u$ 的子樹中選了 $i$ 個點、且 $u$ 不選(0)或選(1)時的答案。
轉移顯然就是一個卷積的形式。
考慮重鏈剖分。
先把輕兒子捲到根,這樣只需要考慮重鏈。
再考慮重鏈上的DP。
設 $g_{u,0/1,0/1}$ 表示重鏈上一段的頭、尾各自選或不選時的答案。合併相鄰兩段同樣是卷積,可以用分治 + FFT(程式中為 NTT)。

#include <bits/stdc++.h>
#define show(x) cerr << #x << " = " << x << endl
using namespace std;

// Modulus used by every arithmetic helper in this file.
constexpr int MOD = 998244353;
// Capacity of the per-vertex arrays (vertices are 1-based).
constexpr int N = 202020;

// 64-bit integer used for intermediate products before reduction.
using ll = long long;
// A polynomial stored as its coefficient list; index = exponent.
using poly = vector<int>;
// A pair of polynomials indexed by a 0/1 state (v0 for state 0, v1 for
// state 1).  Instances are stored per vertex and are also the element type
// of a std::priority_queue.
struct arr {
    poly v0, v1;
    arr(void) {}
    // Sink constructor: both arguments are taken by value and moved into the
    // members, so passing a temporary costs no extra deep copy of the
    // coefficient vectors.
    arr(poly _v0, poly _v1): v0(std::move(_v0)), v1(std::move(_v1)) {}
    // Number of coefficients of v0 (used as the ordering key below).
    inline int size(void) const {
        return v0.size();
    }
    // Deliberately reversed comparison: an arr is "less" when it is LARGER,
    // so a max-heap (std::priority_queue) yields the smallest one first.
    inline bool operator <(const arr &b) const {
        return b.size() < size();
    }
};
// Four polynomials describing a contiguous segment of a heavy chain.
// In vXY, X is the selection state of the segment's first vertex and Y that
// of its last vertex (0 = not selected, 1 = selected).  Member order matters:
// the struct is brace-initialised positionally elsewhere in the file.
struct qua {
    poly v00;
    poly v01;
    poly v10;
    poly v11;
};

// An empty polynomial, used as a "zero" placeholder.
poly blank;

// Buffered single-character reader over stdin.
// Returns the next byte, or EOF (converted to char) once fread() yields
// nothing more.
inline char get(void) {
    static char buf[100000];
    static char *S = buf, *T = buf;
    if (S != T) return *S++;
    // Buffer exhausted: refill it from the start.
    S = buf;
    T = buf + fread(buf, 1, 100000, stdin);
    if (S == T) return EOF;
    return *S++;
}
// Parses the next decimal integer from get() into x.
//
// Every non-digit before the number is skipped; if any skipped character is
// '-', the result is negated.  The character that ends the digit run is
// consumed as well.
//
// If the input ends before a digit is found, x is left at 0 and the function
// returns.  The original loop never tested for end of input, so it spun
// forever once get() kept returning EOF.  EOF is compared after narrowing to
// char because get() returns char; as a consequence a literal 0xFF byte is
// treated as end of input too.
template<typename T>
inline void read(T &x) {
    const char kEnd = static_cast<char>(EOF);
    x = 0;
    bool negative = false;
    char c = get();
    while (c < '0' || c > '9') {
        if (c == kEnd) return;          // truncated input: give up with x == 0
        if (c == '-') negative = true;
        c = get();
    }
    for (; c >= '0' && c <= '9'; c = get()) x = x * 10 + c - '0';
    if (negative) x = -x;
}

inline int pwr(int a, int b) {
    int c = 1;
    while (b) {
        if (b & 1) c = (ll)c * a % MOD;
        b >>= 1; a = (ll)a * a % MOD;
    }
    return c;
}
// Modular inverse of x, computed as x^(MOD-2) via pwr().
inline int inv(int x) {
    const int exponent = MOD - 2;
    return pwr(x, exponent);
}
// Returns a + b with a single conditional subtraction of MOD.
inline int sum(int a, int b) {
    const int total = a + b;
    if (total >= MOD) return total - MOD;
    return total;
}
// Returns a - b, adding MOD back when the difference would be negative.
inline int sub(int a, int b) {
    int diff = a - b;
    if (a < b) diff += MOD;
    return diff;
}
inline void add(int &x, int a) {
    x = sum(x, a);
}

namespace FNT {
    const int MAXN = 303030;
    int ww[MAXN], iw[MAXN];
    int rev[MAXN];
    int num;
    inline void pre(int n) {
        num = n;
        int g = pwr(3, (MOD - 1) / n);
        ww[0] = iw[0] = 1;
        for (int i = 1; i < num; i++)
            iw[n - i] = ww[i] = (ll)ww[i - 1] * g % MOD;
    }
    inline void fnt(int *a, int n, int f) {
        static int x, y, *w;
        w = (f == 1) ? ww : iw;
        for (int i = 0; i < n; i++)
            if (rev[i] > i)
                swap(a[rev[i]], a[i]);
        for (int i = 1; i < n; i <<= 1)
            for (int j = 0; j < n; j += (i << 1))
                for (int k = 0; k < i; k++) {
                    x = a[j + k];
                    y = (ll)a[j + k + i] * w[num / (i << 1) * k] % MOD;
                    a[j + k] = sum(x, y);
                    a[j + k + i] = sub(x, y);
                }
        if (f == -1){
            int in = inv(n);
            for (int i = 0; i < n; i++)
                a[i] = (ll)a[i] * in % MOD;
        }
    }
}

// Polynomial product modulo MOD.
//
// An empty operand stands for the zero polynomial, so the product is empty.
// Small products (|a| * |b| <= 10000) use the schoolbook double loop; larger
// ones go through FNT::fnt with the length padded to a power of two.
//
// Precondition: the padded length must not exceed FNT::num (set by
// FNT::pre), otherwise the root table cannot serve the transform.
inline poly operator *(poly a, poly b) {
    if (a.empty() || b.empty())
        return a.empty() ? a : b;
    using namespace FNT;
    // Scratch buffers for the two transforms.  They are sized with
    // FNT::MAXN, like rev[], because the padded length l can be as large as
    // FNT::num; the previous size N (202020) was overrun whenever the result
    // had more than 131072 coefficients (l == 262144).
    // Invariant: both buffers are all-zero on entry and on exit.
    static int p[MAXN], q[MAXN];
    int m = a.size() + b.size() - 1;
    poly c(m);
    if ((ll)a.size() * b.size() <= 10000) {
        for (size_t i = 0; i < a.size(); i++)
            for (size_t j = 0; j < b.size(); j++)
                add(c[i + j], (ll)a[i] * b[j] % MOD);
        return c;
    }
    for (size_t i = 0; i < a.size(); i++) p[i] = a[i];
    for (size_t i = 0; i < b.size(); i++) q[i] = b[i];
    // l = smallest power of two >= m, L = log2(l) - 1 (top bit position).
    int l = 1, L = 0;
    for (; l < m; l <<= 1) ++L;
    --L;
    for (int i = 0; i < l; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << L);
    fnt(p, l, 1); fnt(q, l, 1);
    for (int i = 0; i < l; i++)
        p[i] = (ll)p[i] * q[i] % MOD;
    fnt(p, l, -1);
    for (int i = 0; i < m; i++) c[i] = p[i];
    // Restore the all-zero invariant for the next call.
    for (int i = 0; i < l; i++) p[i] = q[i] = 0;
    return c;
}
// Coefficient-wise sum modulo MOD; the result is as long as the longer
// operand (an empty operand contributes nothing).
inline poly operator +(poly a, poly b) {
    poly total(max(a.size(), b.size()));
    // Adds every coefficient of src onto the matching slot of total.
    auto fold = [&total](const poly &src) {
        for (size_t i = 0; i != src.size(); ++i)
            total[i] = sum(total[i], src[i]);
    };
    fold(a);
    fold(b);
    return total;
}
// Concatenates two adjacent chain segments (a first, then b).
// For every result state, all combinations of a's last vertex and b's first
// vertex are summed except the one where both are selected (a.vX1 with
// b.v1Y never appears).
inline qua operator *(qua a, qua b) {
    qua joined;
    joined.v00 = a.v00 * b.v00 + a.v00 * b.v10 + a.v01 * b.v00;
    joined.v01 = a.v00 * b.v01 + a.v00 * b.v11 + a.v01 * b.v01;
    joined.v10 = a.v10 * b.v00 + a.v10 * b.v10 + a.v11 * b.v00;
    joined.v11 = a.v10 * b.v01 + a.v10 * b.v11 + a.v11 * b.v01;
    return joined;
}
// Component-wise product: state-0 polynomials are multiplied together, and
// state-1 polynomials are multiplied together.
inline arr operator *(arr a, arr b) {
    arr product;
    product.v0 = a.v0 * b.v0;
    product.v1 = a.v1 * b.v1;
    return product;
}

// Adjacency lists, indexed by vertex.
vector<int> G[N];

// Records an undirected edge: `to` is appended to G[from], then `from` to
// G[to].
inline void addEdge(int from, int to) {
    const int ends[2] = {from, to};
    for (int side = 0; side < 2; ++side)
        G[ends[side]].push_back(ends[side ^ 1]);
}

// n, m: the two integers read first in main(); n is the vertex count and m
// is the coefficient index that main() prints.  clc: running visit counter
// incremented by dfs1().
int n, m, clc;
// Per-vertex data written by dfs1():
//   fa[u]   parent of u (0 for the root),
//   son[u]  child with the largest subtree (0 if none),
//   size[u] number of vertices in u's subtree,
//   pre[u]  visit index of u,
//   erp[i]  vertex whose visit index is i (inverse of pre).
int fa[N], son[N], size[N], pre[N], erp[N];
// w[i]: value read for vertex i in main().
int w[N];
// f[v]: pair of polynomials stored by fuck(v) for the chain starting at v.
arr f[N];

inline void dfs1(int u) {
    size[u] = 1;
    pre[u] = ++clc;
    erp[clc] = u;
    for (int to: G[u]) {
        if (to == fa[u]) continue;
        fa[to] = u; dfs1(to);
        size[u] += size[to];
        if (size[to] > size[son[u]])
            son[u] = to;
    }
}

// Work queue of pending factors; because arr::operator< is reversed, top()
// is the element with the fewest v0 coefficients.
priority_queue<arr> Q;

// Repeatedly multiplies the two smallest queued elements and pushes the
// product back, until one element is left; returns that element and leaves
// Q empty.  If Q is already empty on entry, the value kept from the previous
// call is returned unchanged.
inline arr merge(void) {
    static arr first, second;
    for (;;) {
        if (Q.empty()) return first;
        first = Q.top();
        Q.pop();
        if (Q.empty()) return first;
        second = Q.top();
        Q.pop();
        Q.push(first * second);
    }
}

// The constant polynomial 1.
inline poly f0(void) {
    return poly(1, 1);
}
// The polynomial w * x: coefficients {0, w}.
inline poly f1(int w) {
    poly linear(2, 0);
    linear[1] = w;
    return linear;
}

// Per-vertex segments of the chain currently being processed, in chain order.
vector<qua> lt;

// Product of lt[l..r] (inclusive), computed by splitting the range in half.
// Requires l <= r.
inline qua divAndConq(int l, int r) {
    if (l != r) {
        const int mid = (l + r) >> 1;
        qua left = divAndConq(l, mid);
        qua right = divAndConq(mid + 1, r);
        return left * right;
    }
    return lt[l];
}

// Debug helper: prints the coefficients of x to stderr as "{ a, b, }".
inline void watch(poly x) {
    cerr << "{ ";
    for (size_t i = 0; i < x.size(); ++i)
        cerr << x[i] << ", ";
    cerr << "}" << endl;
}
// Debug helper: prints both polynomials of x, v0 first, one per line.
inline void watch(arr x){
    const poly *parts[2] = {&x.v0, &x.v1};
    for (int k = 0; k < 2; ++k)
        watch(*parts[k]);
}

inline void fuck(int v) {
    lt.clear();
    for (int u = v; u; u = son[u]) {
        for (int to: G[u])
            if (to != son[u] && to != fa[u])
                Q.push(arr(f[to].v0 + f[to].v1, f[to].v0));
        arr cur;
        if (!Q.empty()) cur = merge();
        //watch(cur);
        if (cur.v0.empty()) cur.v0 = f0();
        if (cur.v1.empty()) cur.v1 = f1(w[u]);
        else cur.v1 = cur.v1 * f1(w[u]);
        //watch(cur);
        lt.push_back(qua{cur.v0, blank, blank, cur.v1});
    }
    qua res = divAndConq(0, lt.size() - 1);
    f[v] = arr(res.v00 + res.v01, res.v10 + res.v11);
    //watch(f[v]);
}

// Entry point: reads n, m, the n vertex values and the n-1 edges, runs the
// chain DP bottom-up and prints coefficient m of f[1].v0 + f[1].v1 (mod MOD).
int main(void) {
    // NOTE(review): stdin/stdout are redirected to local files "1.in" and
    // "1.out"; this looks like a local-testing leftover. Confirm before
    // submitting to a judge that uses standard streams.
    freopen("1.in", "r", stdin);
    freopen("1.out", "w", stdout);
    FNT::pre(1 << 18);
    read(n); read(m);
    for (int i = 1; i <= n; i++) read(w[i]);
    for (int i = 1; i < n; i++) {
        int x, y;
        read(x); read(y);
        addEdge(x, y);
    }
    dfs1(1);
    // Reverse visit order, so every vertex is handled after its whole
    // subtree.  Only chain heads (vertices that are not their parent's son)
    // start a chain computation.
    for (int i = n; i >= 1; i--) {
        int u = erp[i];
        if (son[fa[u]] != u) fuck(u);
    }
    int ans = 0;
    // Index m is valid only when size() > m.  The previous test (size() >= m)
    // read one element past the end whenever size() == m.
    if (m >= 0) {
        const size_t want = static_cast<size_t>(m);
        if (want < f[1].v0.size()) ans = sum(ans, f[1].v0[want]);
        if (want < f[1].v1.size()) ans = sum(ans, f[1].v1[want]);
    }
    cout << ans << endl;
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章