CF1042E Vasya and Magic Matrix

CF1042E Vasya and Magic Matrix


題意

一個nm列的矩陣,每個位置有權值ai,ja_{i,j}

給定一個出發點,每次可以等概率的移動到一個權值小於當前點權值的點,同時得分加上兩個點之間歐幾里得距離的平方(歐幾里得距離:(x1x2)2+(y1y2)2\sqrt{(x_1-x_2)^2+(y_1-y_2)^2},問得分的期望


思路

按照計算期望的一般思路

我們可以考慮先計算小的點的期望,在用小的點的期望計算大的點的期望

我們先從小到大排序

轉移方程

dp[i]=jaj<aidp[j]+(xixj)2+(yiyj)2dp[i]= \sum_{j}^{a_{j}<a_{i}}{dp[j]+(x_{i}-x_{j})^2+(y_{i}-y_{j})^2}

前綴和優化

  • (xixj)2=xi2+2xixj+xj2(x_{i}-x_{j})^2 =x_{i}^2 +2*x_{i}*x_{j} +x_{j}^2

    所以我們只要知道xi\sum{x_{i}}xi2\sum {x_{i}^2}就可以O(1)轉移

  • 我們維護前綴和sum=dp[j]sum=\sum{dp[j]}sx=x[j]sx=\sum{x[j]}x2=x[j]2x2=\sum{x[j]^2}sy=y[j]sy=\sum{y[j]}y2=y[j]2y2=\sum{y[j]^2}

  • dp[i]=sum+sx+sy2x[i]sx2y[i]sy+(x[i]2+y[i]2)(i)dp[i]=sum+sx+sy-2*x[i]*sx-2*y[i]*sy+(x[i]^2+y[i]^2)*(小於i的點數)

ans = (sum + x2 + y2) % mod;
LL inv = quickpow(p - 1, mod - 2);
dp[j] = ((ans - sx * 2 * node[j].x - sy * 2 * node[j].y) % mod + mod) % mod;
dp[j] = (dp[j] + LL(p - 1) * node[j].x * node[j].x % mod + LL(p - 1) * node[j].y * node[j].y % mod) % mod;
dp[j] = dp[j] * inv % mod;

代碼

#include <bits/stdc++.h>
using namespace std;
struct Node {
    int data;
    int x;
    int y;
    friend bool operator<(const Node& a, const Node& b)
    {
        return a.data < b.data;
    }
};
const int maxn = 1001 * 1001 * 5;
Node node[maxn];
typedef long long LL;
const LL mod = 998244353;
LL quickpow(LL m, LL p)
{
    LL res = 1;
    while (p) {
        if (p & 1)
            res = res * m % mod;
        m = m * m % mod;
        p >>= 1;
    }
    return res;
}
LL dp[maxn];
int main()
{
    int n, m, k = 0;
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            scanf("%d", &node[++k].data);
            node[k].x = i;
            node[k].y = j;
        }
    }
    int ex, ey;
    scanf("%d%d", &ex, &ey);
    sort(node + 1, node + 1 + k);
    bool flag = false;
    LL sx, x2, sy, y2, sum, ans;
    sum = sx = sy = x2 = y2 = 0;
    for (int i = 1; i <= k;) {
        int p = i;
        while (++i <= k && node[i].data == node[p].data)
            ;
        ans = (sum + x2 + y2) % mod;
        LL inv = quickpow(p - 1, mod - 2);
        for (int j = p; j < i; j++) {
            dp[j] = ((ans - sx * 2 * node[j].x - sy * 2 * node[j].y) % mod + mod) % mod;
            dp[j] = (dp[j] + LL(p - 1) * node[j].x * node[j].x % mod + LL(p - 1) * node[j].y * node[j].y % mod) % mod;
            dp[j] = dp[j] * inv % mod;
            if (node[j].x == ex && node[j].y == ey) {
                flag = true;
                ans = dp[j];
                break;
            }
        }
        if (flag)
            break;
        for (int j = p; j < i; j++)
            sum = (sum + dp[j]) % mod,
            sx += node[j].x,
            x2 += LL(node[j].x) * node[j].x,
            sy += node[j].y,
            y2 += LL(node[j].y) * node[j].y;
    }
    printf("%lld\n", ans);
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章