乘積求和

【題目描述】

給出N,以及N各數,求$ \sum {l = 1} ^ N \sum{r = 1}^N \sum_{i = l}^r \sum_{j = i +1, a[j] < a[i]}^r a[i] *a[j] $

輸出答案對$10^{12} +7 $取模的結果

40%:\(N \leq 50\)

60%:\(N \leq100\)

80%:\(N\leq 1000\)

90%:\(1 \leq a_i \leq 10^5\)

100%:\(N\leq4 * 10^4,1 \leq a_i \leq 10^{12}\)

Solution

首先觀察式子發現他讓我們求序列所有子區間的嚴格逆序對乘積和

最簡單的暴力枚舉,\(N^4\)

手推一下發現,對於每一個數,都要與其後面的比他小的數相乘多次,這個次數就是包含他們的區間個數。那麼顯然有i * (n - j +1)個區間,其中 i 是大的數的位置, j 是小的數的位置。

這樣是\(N^2\)的,可以得到80pt

#include <iostream>
#include <cstdio>
using namespace std;
inline long long read() {
  long long x = 0; int f = 0; char c = getchar();
  while (c < '0' || c > '9') f |= c == '-', c = getchar();
  while (c >= '0' && c <= '9') x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
  return f ? -x : x;
}

const long long mod = 1e12 + 7;
int n;
long long a[40004], ans;
int main() {
  freopen("multiplication.in", "r", stdin);
  freopen("multiplication.out", "w", stdout);
  n = read();
  for (int i = 1; i <= n; ++i) a[i] = read();
  for (int i = 1; i <= n; ++i) {
    long long sum = 0;
    for (int j = i + 1; j <= n; ++j) {
      if (a[j] < a[i]) ans = (ans + a[i] * a[j] % mod * (n - j + 1) % mod * i % mod) % mod;
    }
  }
  printf("%lld\n", ans);
  return 0;
}

然後我們發現對於每一個大的數,它的貢獻是a[i] * i,每一個小的數的貢獻是a[j] * (n - j + 1),他們兩兩沒有關係。於是可以邊加入邊計算。

搞一棵權值線段樹,每加入一個點,首先計算比他先插入的且比他大的數的貢獻,然後再乘以(n - i + 1),再將a[i] * i 加入線段樹就好了。

但會發現這麼搞只有90pt,因爲模數太大了,一乘就暴long long 了,所以再加一個快速龜速乘就好了

#include <iostream>
#include <cstdio>
using namespace std;
inline long long read() {
  long long x = 0; int f = 0; char c = getchar();
  while (c < '0' || c > '9') f |= c == '-', c = getchar();
  while (c >= '0' && c <= '9') x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
  return f ? -x : x;
}

struct szh {
  int l, r;
  long long sum;
  szh() { l = 0; r = 0; sum = 0; }
}a[1600004];
#define mid ((l + r) >> 1)
const long long mod = 1e12 + 7;
int n, cnt = 1;
inline void pushup(int u) {
  a[u].sum = (a[a[u].l].sum + a[a[u].r].sum) % mod;
}
inline long long get_sum(long long l, long long r, long long u, long long L, long long R) {
  if (!u) return 0;
  if (L <= l && r <= R){
    return a[u].sum;
  }
  long long ans = 0;
  if (L <= mid) ans = get_sum(l, mid, a[u].l, L, R);
  if (mid < R) ans = (ans + get_sum(mid + 1, r, a[u].r, L ,R)) % mod;
  return ans;
}
inline void add(long long l, long long r, long long u, long long p, long long v) {
  if (l == r) {
    a[u].sum = (a[u].sum + v) % mod; return;
  }
  if (p <= mid) {
    if (!a[u].l) a[u].l = ++cnt;
    add(l, mid, a[u].l, p, v);
  }
  else {
    if (!a[u].r) a[u].r = ++cnt;
    add(mid + 1, r, a[u].r, p, v);
  }
  pushup(u);
}
inline long long mul(long long x, long long b) {
  long long ans = 0;
  while (b) {
    if (b & 1) ans = (ans + x) % mod;
    x <<= 1; x %= mod;
    b >>= 1;
  }
  return ans;
}
int main() {
  freopen("multiplication.in", "r", stdin);
  freopen("multiplication.out", "w", stdout);
  n = read();
  long long ans = 0;
  for (int i = 1; i <= n; ++i) {
    long long x = read();
    long long xx = x;
    x = mul(x, get_sum(1, mod, 1, x + 1, mod));
    x = x * (n - i + 1) % mod;
    ans = (ans + x) % mod;
    add(1, mod, 1, xx, 1ll * xx * i % mod);
  }
  printf("%lld\n", ans);
  return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章