題目鏈接
題意
- 給你一棵樹,求有多少有序點對 (u,v),使得 u 到 v 路徑上所有點權的最大值減去最小值不大於 D
題解
- 挺裸的點分治
- 注意題意要求的是有序點對(所以最後答案要乘 2),且不包括每個點和自己構成的路徑
- 每次找出當前連通塊的重心,求出所有經過重心的方案數,再減去兩個端點落在同一棵子樹中的方案(這些點對會在遞歸處理子樹時另行統計)。求的時候可以把路徑按最大值升序排序,然後用樹狀數組查詢之前的路徑中最小值不小於「當前最大值 − D」的個數,即一段區間和(類似樹狀數組求逆序對的做法);當然也可以按最小值升序排序,然後對前綴二分查詢「當前最大值 − D」,這種方法慎用
複雜度
- $O(n\log^2 n)$
代碼
#pragma comment(linker, "/STACK:102400000,102400000")
#include<cstdio>
#include<iostream>
#include<algorithm>
using namespace std;
// Capacity of the per-vertex arrays (vertices are 1-indexed).
// NOTE(review): presumably sized for n <= 2e5 - confirm against the problem limits.
constexpr int maxn=200000+10;
// NOTE(review): not referenced anywhere in the visible code.
constexpr int maxm=3000000+10;
int n;        // number of vertices in the current test case
int k;        // allowed spread D: max weight minus min weight on a path
int a[maxn];  // vertex weights, a[1..n]
int c[maxn];  // NOTE(review): not referenced anywhere in the visible code
int d[maxn];  // scratch for calc(): sorted distinct path minima (coordinate compression)
int m;        // number of distinct values in d[1..]; also the Fenwick tree's active size
namespace bit{
// Fenwick tree (binary indexed tree) of counts over positions 1..m.
// The active size is the global m; callers set m and then call init(m).
int s[maxn];
// Value of the lowest set bit of x, i.e. how many positions node x covers.
int lowbit(int x) {return x&(-x);}
// Adds val at position id by updating every node whose range contains id.
void add(int id,int val){
    while(id<=m){
        s[id]+=val;
        id+=lowbit(id);
    }
}
// Returns the sum over positions 1..id (0 when id < 1).
int query(int id){
    int total=0;
    for(;id>0;id-=lowbit(id)) total+=s[id];
    return total;
}
// Returns the sum over positions l..r.
int sum(int l,int r) {
    const int upto_r=query(r);
    const int below_l=query(l-1);
    return upto_r-below_l;
}
// Zeroes nodes 0..n so the tree can be reused for a fresh sweep.
void init(int n) {
    for(int i=n;i>=0;--i) s[i]=0;
}
}
using namespace bit;
namespace point_divide_and_conquer{
// Centroid decomposition that counts unordered vertex pairs {u,v}, u != v, whose
// path has (max vertex weight - min vertex weight) <= k.
// Reads the globals a[] (weights) and k, uses d[]/m as compression scratch, and
// uses the Fenwick tree in namespace bit.

// tot: number of stored directed edges; head[u]: index of u's first edge (0 = none);
// siz[u]: subtree size from the most recent dfs_size; root/min_son: centroid chosen by
// dfs_root and the size of its largest remaining part; num: entries filled in data[].
int tot,head[maxn],siz[maxn],root,min_son,num;
// vis[u] becomes true once u has been used as a centroid; every DFS below skips
// such vertices, which splits the tree into independent components.
bool vis[maxn];
// Directed edge to v with weight w (w is stored but never read here); next points
// to the edge previously added for the same source vertex.
struct ed{int v,w,next;}edge[2*maxn];
// Minimum and maximum vertex weight along one path that starts at the current
// centroid. Ordered by maxx only, so sorting arranges paths by their maximum.
struct da{
    int minn,maxx;
    da(int a=0,int b=0) {minn=a;maxx=b;}
    friend bool operator<(const da &a,const da &b) {
        return a.maxx<b.maxx;
    }
}data[maxn];
// Resets the edge counter, adjacency lists and centroid marks for vertices 1..n.
inline void init(int n) {
    tot=0;
    for(int i=1;i<=n;i++) head[i]=0,vis[i]=false;
}
// Prepends the directed edge u -> v (weight w) to u's adjacency list.
inline void add_edge(int u,int v,int w) {
    edge[++tot]=ed{v,w,head[u]};
    head[u]=tot;
}
// Computes siz[] for the not-yet-removed component below cur (coming from fa).
inline void dfs_size(int cur,int fa) {
    siz[cur]=1;
    for(int i=head[cur];i;i=edge[i].next) {
        if(edge[i].v!=fa && !vis[edge[i].v]) {
            dfs_size(edge[i].v,cur);
            siz[cur]+=siz[edge[i].v];
        }
    }
}
// Finds the centroid of a component of `all` vertices: the vertex whose largest
// remaining part (a child subtree, or everything above it) is smallest.
// Requires siz[] from dfs_size rooted at the same starting vertex; updates
// root and min_son.
inline void dfs_root(int cur,int fa,int all) {
    int max_son=all-siz[cur];
    for(int i=head[cur];i;i=edge[i].next) {
        if(edge[i].v!=fa && !vis[edge[i].v]) {
            max_son=max(max_son,siz[edge[i].v]);
            dfs_root(edge[i].v,cur,all);
        }
    }
    if(max_son<min_son) min_son=max_son,root=cur;
}
// For every vertex v reachable from cur without passing fa or a removed vertex,
// appends to data[] the (min,max) of the weights on the cur..v path, merged with
// the caller-supplied running (minn,maxx). Only entries with maxx-minn <= k are kept.
inline void dfs_roote(int cur,int fa,int minn,int maxx) {
    // Extending a path can only lower its minimum or raise its maximum, so once
    // the spread exceeds k no deeper path can qualify either. Prune here instead
    // of walking the rest of the subtree for nothing.
    if(maxx-minn>k) return;
    data[++num]=da{minn,maxx};
    for(int i=head[cur];i;i=edge[i].next) {
        if(edge[i].v!=fa && !vis[edge[i].v]) {
            dfs_roote(edge[i].v,cur,min(a[edge[i].v],minn),max(maxx,a[edge[i].v]));
        }
    }
}
// Counts pairs of paths collected from cur whose combined spread is still <= k.
// Paths are swept in non-decreasing maxx. For path i the combined maximum is
// data[i].maxx, and data[i] alone already qualifies, so the pair qualifies exactly
// when the earlier path's minimum is >= data[i].maxx - k. Those earlier minima are
// counted with the Fenwick tree over the compressed values d[1..m].
inline long long calc(int cur,int fa,int minn,int maxx) {
    num=0;
    dfs_roote(cur,fa,minn,maxx);
    sort(data+1,data+num+1);
    // Coordinate-compress the path minima into d[1..m].
    for(int i=1;i<=num;i++) d[i]=data[i].minn;
    sort(d+1,d+num+1);
    m=unique(d+1,d+num+1)-d-1;
    bit::init(m);
    long long ans=0;
    for(int i=1;i<=num;i++) {
        int loc=lower_bound(d+1,d+m+1,data[i].minn)-d;
        // First compressed index whose minimum is >= data[i].maxx - k.
        int pos=lower_bound(d+1,d+m+1,data[i].maxx-k)-d;
        if(pos<=m) ans+=bit::sum(pos,m);
        bit::add(loc,1);
    }
    return ans;
}
// Returns the number of unordered pairs {u,v}, u != v, inside the component that
// contains cur (vertices not yet removed) whose path spread is <= k.
// Pairs combined through the centroid are counted by calc; pairs whose endpoints lie
// in the same child subtree are subtracted and then counted by the recursive calls.
inline long long solve(int cur) {
    min_son=0x3f3f3f3f;
    dfs_size(cur,0);
    dfs_root(cur,0,siz[cur]);
    vis[root]=true;
    long long ans=calc(root,0,a[root],a[root]);
    for(int i=head[root];i;i=edge[i].next) {
        if(!vis[edge[i].v]) {
            // Same-subtree pairs, measured along paths that include root. fa=0 is
            // safe because vis[root] already blocks the walk back into root.
            ans-=calc(edge[i].v,0,min(a[root],a[edge[i].v]),max(a[root],a[edge[i].v]));
        }
    }
    // root is a namespace global that the recursive calls overwrite; head[root] is
    // read only once at loop start, so this still walks this centroid's edges.
    for(int i=head[root];i;i=edge[i].next) if(!vis[edge[i].v]) ans+=solve(edge[i].v);
    return ans;
}
}
using namespace point_divide_and_conquer;
// Reads t test cases; each gives n, k, the n vertex weights and n-1 undirected
// edges, and prints the number of ordered pairs (u,v), u != v, whose path has
// max weight - min weight <= k.
int main() {
    int t=0;
    // Stop cleanly on missing/malformed input instead of looping on an uninitialized count.
    if(scanf("%d",&t)!=1) return 0;
    while(t--) {
        if(scanf("%d %d",&n,&k)!=2) break;
        if(n<=0) {
            // An empty tree has no pairs; avoid running solve(1) on stale adjacency data.
            puts("0");
            continue;
        }
        point_divide_and_conquer::init(n);
        for(int i=1;i<=n;i++) scanf("%d",&a[i]);
        for(int i=1,u,v;i<n;i++) {
            scanf("%d %d",&u,&v);
            add_edge(u,v,1);
            add_edge(v,u,1);
        }
        // solve counts unordered pairs; each one yields two ordered pairs.
        printf("%lld\n",2*solve(1));
    }
    return 0;
}