There are n points on a coordinate axis OX. The i-th point is located at the integer point xi and has a speed vi. It is guaranteed that no two points occupy the same coordinate. All n points move with the constant speed, the coordinate of the i-th point at the moment t (t can be non-integer) is calculated as xi+t⋅vi.
Consider two points i and j. Let d(i,j) be the minimum possible distance between these two points over any possible moments of time (even non-integer). It means that if two points i and j coincide at some moment, the value d(i,j) will be 0.
Your task is to calculate the value ∑1≤i<j≤n d(i,j) (the sum of minimum distances over all pairs of points).
Input
The first line of the input contains one integer n (2≤n≤2⋅105) — the number of points.
The second line of the input contains n integers x1,x2,…,xn (1≤xi≤108), where xi is the initial coordinate of the i-th point. It is guaranteed that all xi are distinct.
The third line of the input contains n integers v1,v2,…,vn (−108≤vi≤108), where vi is the speed of the i-th point.
Output
Print one integer — the value ∑1≤i<j≤n d(i,j) (the sum of minimum distances over all pairs of points).
Examples
inputCopy
3
1 3 2
-100 2 3
outputCopy
3
inputCopy
5
2 1 4 3 5
2 2 2 3 4
outputCopy
19
inputCopy
2
2 1
-3 0
outputCopy
0
題意:
在一個座標軸上,給你n個點,每個點都有兩個屬性一個是xi(代表位置),一個是vi代表每秒速度。 每個點移動 xi+t*vi,d(i,j)表示i到j的最短距離。’
問這個公式的最小值是多少。
解析:
xi<xj 且 vi>vj 那麼d(i,j)一定等於0,因爲在某一時刻可以相遇 這樣的情況對答案的貢獻爲0
xi<xj 且 vi<=vj 那麼d(i,j)=abs(xi-xj)。因爲不管什麼時候i永遠無法遇到j 這樣的情況對答案的貢獻爲 abs(xi-xj);
那麼現在我假設 位於xi左邊的有 x0,x1,x2,x3,x4…xn 且速度都小於 vi
對答案的貢獻爲 (xi-x0)+(xi-x1)+(xi-x2)+(xi-x3)+(xi-x4)+…+(xi-xn)
整合一下就是 nxi-(x0+x1+x2+x3+x4+…+xn)
這具備了前綴和。所以我們用樹狀數組維護。
c[0][x]維護x左邊數出現的個數
c[1][x]維護x左邊數的總和
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=2e5+100;
ll c[2][N];
ll b[N];
int n;
struct node
{
ll x,v;
}a[N];
bool cmp(const node &a,const node &b)
{
return a.x<b.x;
}
int lowbit(int x)
{
return x&(-x);
}
ll sum(int x,int k)
{
ll res=0;
while(k)
{
res=res+c[x][k];
k-=lowbit(k);
}
return res;
}
ll add(int x,int val)
{
while(x<=N)
{
c[0][x]++;
c[1][x]+=val;
x+=lowbit(x);
}
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%lld",&a[i].x);
for(int i=1;i<=n;i++)
{
scanf("%lld",&a[i].v);
b[i]=a[i].v;
}
sort(a+1,a+1+n,cmp);
sort(b+1,b+1+n);
ll m=unique(b+1,b+1+n)-b-1;
ll ans=0;
// for(int i=1;i<=n;i++) cout<<b[i]<<endl;
for(int i=1;i<=n;i++)
{
int x=lower_bound(b+1,b+1+m,a[i].v)-b;
// cout<<x<<endl;
ans=ans+(a[i].x*sum(0,x))-sum(1,x);
add(x,a[i].x);
}
cout<<ans<<endl;
}