題意:給定一棵樹,邊上有權,每個點有一個顏色A,q次詢問,每次詢問z,x,y表示顏色在[x,y]的所有點到點z的距離之和。
數據範圍:滿足 n<=150000,Q<=200000。對於所有數據,滿足 A<=10^9
感覺這題非常難啊完全不會啊看上去是個動態樹分治可是想了半天啥都沒想到。。。。。
考慮dis(u,v)=dep(u)+dep(v)-2*dep(lca(u,v))(dis(u,v)表示u,v的距離,dep(x)表示根節點到x的距離),假如沒有顏色的限制,那麼問題就變成所有點到x的距離,假如我們從上面的式子進行考慮,u是確定的,v(就是1到n所有的節點)。dep(u)可以通過dfs求出,dep(v)(其實就是所有點到根的距離)可以通過dfs之後累加,關鍵點就是求x和v(其中1<=v<=n)的lca到根的距離和。
具體做法就是我們進行樹鏈剖分,用線段樹維護,線段樹葉子節點維護的是對應的點和它的父親這條路徑被經過的距離和,非葉子節點維護的是區間和,那麼對於一個節點x(1<=x<=n),我們一路跳到根節點,期間把每個點都加上這個點和這個點父親的距離,一直加到根,區間加可以用線段樹來實現(每次對x到top(x)進行區間加),具體做法的話假設我們做到點x,它的樹鏈剖分序是t,則sum[t]=dis(x,fa(x)),然後我們需要維護sum的前綴和,這樣在線段樹區間加的時候,若是當前區間被目標區間完全包含則對對應節點加上sum[q]-sum[p-1]假設[p,q]是當前遞歸到的節點。查詢x點和其他所有點的lca到根的距離和的時候,我們還是一路跳,把樹上每個點的權值加上去,該過程可以用線段樹優化。
考慮爲什麼這樣做是對的,對於一個點x,它對它和它的祖先(設爲y)的lca必然就是它的祖先本身,然後這樣的貢獻就是相當於對於每個y都加上dis(1,y),即每個y和它父親這條路徑要多經過一次。然後最後查詢相當於是查一個點到根的路徑上每條邊被經過了的距離和。
因爲顏色很大,所以離散化之後用可持久化線段樹。
可持久化線段樹區間加的下放標記要分2種類型,難寫速度慢內存大。以下介紹一種簡便方法:
設當前區間[p,q],目標區間[l,r],則我們在除了完全包含的區間外的所有區間都加上sum(min(r,q))-sum(max(l,p)-1)(這步相當於維護了一個區間和),否則我們另外開數組sum1對當前節點加1;查詢的時候,每次走到一個被部分包含的節點就加上sum1*sum(min(r,q))-sum(max(l,p)-1),若走到完全包含的節點則再加上sum。證明請讀者自行思考。
這傻逼被這題卡了一天。。
代碼:
#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int MAXN = 150005;
const int NOD = 10000005;
int first[MAXN], next[MAXN << 1], go[MAXN << 1], way[MAXN << 1], t;
int dis[MAXN], dep[MAXN], fa[MAXN], size[MAXN], top[MAXN], son[MAXN], pos[MAXN];
int n, m, i, j, k, x, y, z, A, b[MAXN], len, p, f[MAXN];
int sum1[NOD], lc[NOD], rc[NOD], root[MAXN];
long long ans, tot, num[MAXN], sum[NOD];
struct sb{
int point, color;
};
sb a[MAXN];
inline bool rule(const sb &a, const sb &b)
{
return a.color < b.color;
}
inline int get()
{
char c;
while ((c = getchar()) < 48 || c > 57);
int res = c - 48;
while ((c = getchar()) >= 48 && c <= 57)
res = res * 10 + c - 48;
return res;
}
inline void add(const int &x, const int &y, const int &z)
{
next[++t] = first[x]; first[x] = t; go[t] = y; way[t] = z;
next[++t] = first[y]; first[y] = t; go[t] = x; way[t] = z;
}
inline void dfs(int now)
{
size[now] = 1;
int son1 = 0, son2 = 0;
for(int i = first[now]; i; i = next[i])
if (fa[now] != go[i])
{
fa[go[i]] = now;
dep[go[i]] = dep[now] + way[i];
dis[go[i]] = way[i] + dis[now];
dfs(go[i]);
size[now] += size[go[i]];
if (size[go[i]] > son1) son1 = size[go[i]], son2 = go[i];
}
son[now] = son2;
}
inline void dfs1(int now)
{
pos[now] = ++t;
dis[t] = dep[now] - dep[fa[now]];
if (son[now])
{
top[son[now]] = top[now];
dfs1(son[now]);
}
for(int i = first[now]; i; i = next[i])
if (!pos[go[i]])
{
top[go[i]] = go[i];
dfs1(go[i]);
}
}
inline void insert(int &x, int y, int p, int q, int l, int r)
{
x = ++t;
lc[x] = lc[y];
rc[x] = rc[y];
sum[x] = sum[y];
sum1[x] = sum1[y];
if (p >= l && q <= r)
{
sum1[x] ++;
return;
}
sum[x] += dis[min(r, q)] - dis[max(l, p) - 1];
int mid = (p + q) >> 1;
if (mid >= l) insert(lc[x], lc[y], p, mid, l, r);
if (mid < r) insert(rc[x], rc[y], mid + 1, q, l, r);
}
inline void find(int k, int p, int q, int l, int r)
{
if (sum1[k]) tot += sum1[k] * (long long)(dis[min(r, q)] - dis[max(l, p) - 1]);
if (p >= l && q <= r)
{
tot += sum[k];
return;
}
int mid = (p + q) >> 1;
if (mid >= l) find(lc[k], p, mid, l, r);
if (mid < r) find(rc[k], mid + 1, q, l, r);
}
inline long long solve(int x, int y)
{
tot = 0;
while (x)
{
find(root[y], 1, n, pos[top[x]], pos[x]);
x = fa[top[x]];
}
return tot;
}
int main()
{
cin >> n >> m >> A;
for(i = 1; i <= n; i ++)
a[i].color = get(), a[i].point = i, b[++len] = a[i].color;
sort(a + 1, a + 1 + n, rule);
sort(b + 1, b + 1 + len);
len = unique(b + 1, b + 1 + len) - 1 - b;
for(i = 1; i < n; i ++)
{
x = get(); y = get(); z = get();
add(x, y, z);
}
dfs(1);
t = 0; top[1] = 1;
dfs1(1);
t = 0;
a[0].color = -1;
for(i = 2; i <= n; i ++)
dis[i] += dis[i - 1];
for(i = 1; i <= n; i ++)
{
if (a[i].color != a[i - 1].color) p ++, root[p] = root[p - 1];
f[p] ++;
num[p] += dep[a[i].point];
x = a[i].point;
while (x)
{
insert(root[p], root[p], 1, n, pos[top[x]], pos[x]);
x = fa[top[x]];
}
}
for(i = 2; i <= len; i ++)
num[i] += num[i - 1], f[i] += f[i - 1];
while (m --)
{
z = get(); x = get(); y = get();
x = (x + ans) % A;
y = (y + ans) % A;
if (x > y) swap(x, y);
x = lower_bound(b + 1, b + 1 + len, x) - b;
y = upper_bound(b + 1, b + 1 + len, y) - b;
y --;
ans = (f[y] - f[x - 1]) * (long long)dep[z] - 2 * (solve(z, y) - solve(z, x - 1)) + num[y] - num[x - 1];
printf("%lld\n", ans);
}
}