【題目描述】
給出N,以及N各數,求$ \sum {l = 1} ^ N \sum{r = 1}^N \sum_{i = l}^r \sum_{j = i +1, a[j] < a[i]}^r a[i] *a[j] $
輸出答案對$10^{12} +7 $取模的結果
40%:\(N \leq 50\)
60%:\(N \leq100\)
80%:\(N\leq 1000\)
90%:\(1 \leq a_i \leq 10^5\)
100%:\(N\leq4 * 10^4,1 \leq a_i \leq 10^{12}\)
Solution
首先觀察式子發現他讓我們求序列所有子區間的嚴格逆序對乘積和
最簡單的暴力枚舉,\(N^4\)
手推一下發現,對於每一個數,都要與其後面的比他小的數相乘多次,這個次數就是包含他們的區間個數。那麼顯然有i * (n - j +1)個區間,其中 i 是大的數的位置, j 是小的數的位置。
這樣是\(N^2\)的,可以得到80pt
#include <iostream>
#include <cstdio>
using namespace std;
inline long long read() {
long long x = 0; int f = 0; char c = getchar();
while (c < '0' || c > '9') f |= c == '-', c = getchar();
while (c >= '0' && c <= '9') x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
return f ? -x : x;
}
const long long mod = 1e12 + 7;
int n;
long long a[40004], ans;
int main() {
freopen("multiplication.in", "r", stdin);
freopen("multiplication.out", "w", stdout);
n = read();
for (int i = 1; i <= n; ++i) a[i] = read();
for (int i = 1; i <= n; ++i) {
long long sum = 0;
for (int j = i + 1; j <= n; ++j) {
if (a[j] < a[i]) ans = (ans + a[i] * a[j] % mod * (n - j + 1) % mod * i % mod) % mod;
}
}
printf("%lld\n", ans);
return 0;
}
然後我們發現對於每一個大的數,它的貢獻是a[i] * i,每一個小的數的貢獻是a[j] * (n - j + 1),他們兩兩沒有關係。於是可以邊加入邊計算。
搞一棵權值線段樹,每加入一個點,首先計算比他先插入的且比他大的數的貢獻,然後再乘以(n - i + 1),再將a[i] * i 加入線段樹就好了。
但會發現這麼搞只有90pt,因爲模數太大了,一乘就暴long long 了,所以再加一個快速龜速乘就好了
#include <iostream>
#include <cstdio>
using namespace std;
inline long long read() {
long long x = 0; int f = 0; char c = getchar();
while (c < '0' || c > '9') f |= c == '-', c = getchar();
while (c >= '0' && c <= '9') x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
return f ? -x : x;
}
struct szh {
int l, r;
long long sum;
szh() { l = 0; r = 0; sum = 0; }
}a[1600004];
#define mid ((l + r) >> 1)
const long long mod = 1e12 + 7;
int n, cnt = 1;
inline void pushup(int u) {
a[u].sum = (a[a[u].l].sum + a[a[u].r].sum) % mod;
}
inline long long get_sum(long long l, long long r, long long u, long long L, long long R) {
if (!u) return 0;
if (L <= l && r <= R){
return a[u].sum;
}
long long ans = 0;
if (L <= mid) ans = get_sum(l, mid, a[u].l, L, R);
if (mid < R) ans = (ans + get_sum(mid + 1, r, a[u].r, L ,R)) % mod;
return ans;
}
inline void add(long long l, long long r, long long u, long long p, long long v) {
if (l == r) {
a[u].sum = (a[u].sum + v) % mod; return;
}
if (p <= mid) {
if (!a[u].l) a[u].l = ++cnt;
add(l, mid, a[u].l, p, v);
}
else {
if (!a[u].r) a[u].r = ++cnt;
add(mid + 1, r, a[u].r, p, v);
}
pushup(u);
}
inline long long mul(long long x, long long b) {
long long ans = 0;
while (b) {
if (b & 1) ans = (ans + x) % mod;
x <<= 1; x %= mod;
b >>= 1;
}
return ans;
}
int main() {
freopen("multiplication.in", "r", stdin);
freopen("multiplication.out", "w", stdout);
n = read();
long long ans = 0;
for (int i = 1; i <= n; ++i) {
long long x = read();
long long xx = x;
x = mul(x, get_sum(1, mod, 1, x + 1, mod));
x = x * (n - i + 1) % mod;
ans = (ans + x) % mod;
add(1, mod, 1, xx, 1ll * xx * i % mod);
}
printf("%lld\n", ans);
return 0;
}