代码:
```cpp
#include <iostream>
#include <cstdio>
#define int long long
using namespace std;
// Capacity of the per-node arrays below. Note that `int` is a macro for
// `long long` in this file, so every value here is 64-bit.
const int N = 100005;
// All four are read in main(): n values and n-1 edges are read, m operations
// are processed, r is the node both DFS passes start from, and mod is the
// modulus applied to stored and reported sums.
int n, m, r, mod;
// num   : counter dfs2 uses to hand out tid values.
// w     : initial node values read in main().
// fa, depth, sz, son : parent, depth, subtree size and largest-subtree child,
//                      all filled by dfs1.
// tid, top : position assigned by dfs2 and the head of the node's chain.
// nw    : nw[tid[p]] = w[p] % mod, the array the segment tree is built from.
int num, w[N], fa[N], depth[N], sz[N], son[N], tid[N], nw[N], top[N];
// Adjacency-list state used by add(): cnt is the number of edge records
// stored so far, head[u] is the most recently added record of u (0 = none).
int cnt, head[N];
// Accumulator that val() adds into; callers reset it to 0 before each query.
int res;
// One directed adjacency record. Records of the same source node are chained
// through `next`; index 0 terminates a chain.
struct edge {
    int v;     // node this record points to
    int next;  // index of the next record with the same source node
};
// Pool of adjacency records, sized for two records per tree edge.
edge e[N << 1];
// One segment-tree node covering positions [l, r].
struct tree {
    int num;  // sum of the covered positions, kept modulo `mod`
    int lz;   // pending per-position addition not yet pushed to the children
    int l;    // left end of the covered range
    int r;    // right end of the covered range
};
// Segment-tree storage; node p has children 2p and 2p+1.
tree t[N << 2];
// Prepends a directed record u -> v to u's adjacency chain.
void add(int u, int v) {
    cnt += 1;
    const int id = cnt;
    e[id].next = head[u];
    e[id].v = v;
    head[u] = id;
}
// First pass: records parent, depth and subtree size of every node reachable
// from p, and picks for each node the child with the largest subtree (son).
//   p   - current node
//   f   - node we arrived from (skipped when scanning neighbours)
//   dep - depth assigned to p
void dfs1(int p, int f, int dep) {
    fa[p] = f;
    depth[p] = dep;
    sz[p] = 1;
    int heaviest = -1;  // largest child subtree size seen so far
    for (int id = head[p]; id != 0; id = e[id].next) {
        const int child = e[id].v;
        if (child != f) {
            dfs1(child, p, dep + 1);
            sz[p] += sz[child];
            // Strict comparison: on ties the first child encountered stays.
            if (sz[child] > heaviest) {
                heaviest = sz[child];
                son[p] = child;
            }
        }
    }
}
// Second pass: assigns each node its position (tid) and chain head (top),
// and copies the node's value, reduced modulo `mod`, into nw at that position.
// The preferred child is visited first so a chain occupies consecutive
// positions.
//   p  - current node
//   tp - head of the chain p belongs to
void dfs2(int p, int tp) {
    num += 1;
    tid[p] = num;
    top[p] = tp;
    nw[num] = w[p] % mod;

    const int heavy = son[p];
    if (heavy == 0)
        return;  // no child recorded by dfs1: p is a leaf

    dfs2(heavy, tp);
    for (int id = head[p]; id != 0; id = e[id].next) {
        const int child = e[id].v;
        if (child == heavy || child == fa[p])
            continue;
        dfs2(child, child);  // every other child starts a chain of its own
    }
}
// Pushes the pending addition stored on node p down to its two children and
// clears it on p. Does nothing when there is no pending addition.
void p_d(int p) {
    const int pending = t[p].lz;
    if (pending == 0)
        return;
    for (int side = 0; side < 2; ++side) {
        const int child = (p << 1) | side;
        const int len = t[child].r - t[child].l + 1;
        t[child].num = (t[child].num + pending * len) % mod;
        t[child].lz = (t[child].lz + pending) % mod;
    }
    t[p].lz = 0;
}
// Builds the segment-tree node p over positions [l, r] from the nw array.
void build(int p, int l, int r) {
    t[p].l = l;
    t[p].r = r;
    if (l != r) {
        const int lc = p << 1;
        const int rc = lc | 1;
        const int mid = (l + r) >> 1;
        build(lc, l, mid);
        build(rc, mid + 1, r);
        t[p].num = (t[lc].num + t[rc].num) % mod;
    } else {
        t[p].num = nw[l];  // leaf: a single position
    }
}
// Adds the sum of positions [l, r] inside node p's range to the global `res`
// (modulo `mod`). The caller is expected to reset `res` beforehand.
void val(int p, int l, int r) {
    const int lo = t[p].l;
    const int hi = t[p].r;
    if (l <= lo && hi <= r) {
        // Node lies entirely inside the query: take its stored sum.
        res = (res + t[p].num) % mod;
        return;
    }
    if (t[p].lz != 0)
        p_d(p);  // children must be up to date before descending
    const int mid = (lo + hi) >> 1;
    if (mid >= l)
        val(p << 1, l, r);
    if (mid < r)
        val((p << 1) | 1, l, r);
}
// Adds k to every position in [l, r] that falls inside node p's range.
void upd(int p, int l, int r, int k) {
    const int lo = t[p].l;
    const int hi = t[p].r;
    if (l <= lo && hi <= r) {
        // Fully covered: update the sum and remember k for the children.
        t[p].num = (t[p].num + k * (hi - lo + 1)) % mod;
        t[p].lz = (t[p].lz + k) % mod;
        return;
    }
    if (t[p].lz != 0)
        p_d(p);  // children must be up to date before descending
    const int lc = p << 1;
    const int rc = lc | 1;
    const int mid = (lo + hi) >> 1;
    if (mid >= l)
        upd(lc, l, r, k);
    if (mid < r)
        upd(rc, l, r, k);
    t[p].num = (t[lc].num + t[rc].num) % mod;
}
// Returns the sum of the node values on the path between x and y, modulo
// `mod`.
//
// While the endpoints are on different chains, the endpoint whose chain head
// is deeper is lifted: the segment [tid[top[x]], tid[x]] is summed and x
// moves to the parent of its chain head. Once both endpoints share a chain
// they occupy a contiguous range of positions, which is summed directly.
int qroad(int x, int y) {
    int ans = 0;
    while (top[x] != top[y]) {
        // The choice must be made on the depth of the chain heads, not of
        // x and y themselves. A deep node on a chain that starts near the
        // root would otherwise be lifted past the common ancestor (possibly
        // to fa[root] == 0), and the query would sum the wrong segments.
        if (depth[top[x]] < depth[top[y]])
            swap(x, y);
        res = 0;
        val(1, tid[top[x]], tid[x]);
        ans = (ans + res) % mod;
        x = fa[top[x]];
    }
    // Same chain: the shallower node has the smaller position.
    if (depth[x] > depth[y])
        swap(x, y);
    res = 0;
    val(1, tid[x], tid[y]);
    ans = (ans + res) % mod;
    return ans;
}
// Adds k to every node on the path between x and y.
//
// While the endpoints are on different chains, the endpoint whose chain head
// is deeper is lifted: the segment [tid[top[x]], tid[x]] is updated and x
// moves to the parent of its chain head. Once both endpoints share a chain
// the remaining contiguous range of positions is updated directly.
void uroad(int x, int y, int k) {
    while (top[x] != top[y]) {
        // The choice must be made on the depth of the chain heads, not of
        // x and y themselves. A deep node on a chain that starts near the
        // root would otherwise be lifted past the common ancestor (possibly
        // to fa[root] == 0), and nodes off the path would be updated.
        if (depth[top[x]] < depth[top[y]])
            swap(x, y);
        upd(1, tid[top[x]], tid[x], k);
        x = fa[top[x]];
    }
    // Same chain: the shallower node has the smaller position.
    if (depth[x] > depth[y])
        swap(x, y);
    upd(1, tid[x], tid[y], k);
}
// Returns the sum over the subtree of x, modulo `mod`. The subtree occupies
// the sz[x] consecutive positions starting at tid[x].
int qtree(int x) {
    const int first = tid[x];
    const int last = first + sz[x] - 1;
    res = 0;
    val(1, first, last);
    return res;
}
// Adds k to every node in the subtree of x, i.e. to the sz[x] consecutive
// positions starting at tid[x].
void utree(int x, int k) {
    const int first = tid[x];
    const int last = first + sz[x] - 1;
    upd(1, first, last, k);
}
// Reads the tree and the operations from stdin, then answers each query on
// stdout.
//   op 1 x y k : add k to every node on the path x..y
//   op 2 x y   : print the path sum x..y
//   op 3 x k   : add k to every node in the subtree of x
//   op 4 x     : print the subtree sum of x
// `signed` is used because `int` is a macro for `long long` in this file.
signed main() {
    int u, v;
    int op, x, y, k;

    scanf("%lld%lld%lld%lld", &n, &m, &r, &mod);
    for (int i = 1; i <= n; ++i)
        scanf("%lld", &w[i]);
    for (int i = 1; i < n; ++i) {
        scanf("%lld%lld", &u, &v);
        // Store the undirected edge as two directed records.
        add(u, v);
        add(v, u);
    }

    // Decompose from r (its parent is the sentinel 0), then build the tree
    // over the resulting position order.
    dfs1(r, 0, 1);
    dfs2(r, r);
    build(1, 1, n);

    while (m--) {
        scanf("%lld", &op);
        switch (op) {
        case 1:
            scanf("%lld%lld%lld", &x, &y, &k);
            uroad(x, y, k);
            break;
        case 2:
            scanf("%lld%lld", &x, &y);
            printf("%lld\n", qroad(x, y));
            break;
        case 3:
            scanf("%lld%lld", &x, &k);
            utree(x, k);
            break;
        case 4:
            scanf("%lld", &x);
            printf("%lld\n", qtree(x));
            break;
        default:
            break;  // unknown operation codes are ignored
        }
    }
    return 0;
}
```
by farfar @ 2022-07-03 14:19:19
`if (depth[x] < depth[y])`
这里应该是 `if (depth[top[x]] < depth[top[y]])`(别问我为什么能看出来,亲身经历)
by ningago @ 2022-07-03 14:23:39
@[farfar](/user/378951)
by ningago @ 2022-07-03 14:23:49
@[ningago](/user/371968) 感谢大佬!
by farfar @ 2022-07-03 15:17:37