20pts,AC#1#3求助

P3384 【模板】重链剖分/树链剖分

代码: ```cpp #include <iostream> #include <cstdio> #define int long long using namespace std; const int N = 100005; int n, m, r, mod; int num, w[N], fa[N], depth[N], sz[N], son[N], tid[N], nw[N], top[N]; int cnt, head[N]; int res; struct edge { int v, next; } e[N << 1]; struct tree { int num, lz, l, r; } t[N << 2]; void add(int u, int v) { e[++cnt].v = v; e[cnt].next = head[u]; head[u] = cnt; } void dfs1(int p, int f, int dep) { //cout << p << ' ' << f << ' ' << dep << " \n"; fa[p] = f; depth[p] = dep; sz[p] = 1; int maxs = -1; for (int i = head[p]; i; i = e[i].next) { int v = e[i].v; if (v == f) continue; dfs1(v, p, dep + 1); sz[p] += sz[v]; if (maxs < sz[v]) { son[p] = v; maxs = sz[v]; } } } void dfs2(int p, int tp) { tid[p] = ++num; nw[num] = w[p] % mod; top[p] = tp; //cout << p << ' ' << tid[p] << ' ' << sz[p] << '\n'; if (!son[p]) return ; dfs2(son[p], tp); for (int i = head[p]; i; i = e[i].next) { int v = e[i].v; if (v != son[p] && v != fa[p]) dfs2(v, v); } } void p_d(int p) { if (!t[p].lz) return ; int l = p << 1, r = (p << 1) | 1, lz = t[p].lz; t[l].num += lz * (t[l].r - t[l].l + 1); t[r].num += lz * (t[r].r - t[r].l + 1); t[l].num %= mod; t[r].num %= mod; t[l].lz += lz; t[r].lz += lz; t[l].lz %= mod; t[r].lz %= mod; t[p].lz = 0; } void build(int p, int l, int r) { t[p].l = l; t[p].r = r; if (l == r) { t[p].num = nw[l]; return ; } int mid = (l + r) >> 1; build(p << 1, l, mid); build((p << 1) | 1, mid + 1, r); t[p].num = (t[p << 1].num + t[(p << 1) | 1].num) % mod; } void val(int p, int l, int r) { if (l <= t[p].l && t[p].r <= r) { res = (res + t[p].num) % mod; return ; } if (t[p].lz) p_d(p); int mid = (t[p].l + t[p].r) >> 1; if (l <= mid) val(p << 1, l, r); if (r > mid) val((p << 1) | 1, l, r); } void upd(int p, int l, int r, int k) { //cout << p << ' ' << l << ' ' << r << ' ' << k << '\n'; if (l <= t[p].l && t[p].r <= r) { t[p].lz = (t[p].lz + k) % mod; t[p].num = (t[p].num + k * (t[p].r - t[p].l + 1)) % mod; return ; } if (t[p].lz) p_d(p); int mid = (t[p].l 
+ t[p].r) >> 1; if (l <= mid) upd(p << 1, l, r, k); if (r > mid) upd((p << 1) | 1, l, r, k); t[p].num = (t[p << 1].num + t[(p << 1) | 1].num) % mod; } int qroad(int x, int y) { int ans = 0; while (top[x] != top[y]) { if (depth[x] < depth[y]) swap(x, y); res = 0; val(1, tid[top[x]], tid[x]); ans = (ans + res) % mod; x = fa[top[x]]; } if (depth[x] > depth[y]) swap(x, y); res = 0; val(1, tid[x], tid[y]); ans = (ans + res) % mod; return ans; } void uroad(int x, int y, int k) { while (top[x] != top[y]) { if (depth[x] < depth[y]) swap(x, y); upd(1, tid[top[x]], tid[x], k); x = fa[top[x]]; } if (depth[x] > depth[y]) swap(x, y); upd(1, tid[x], tid[y], k); } int qtree(int x) { res = 0; val(1, tid[x], tid[x] + sz[x] - 1); return res; } void utree(int x, int k) { //cout << "1!!\n"; upd(1, tid[x], tid[x] + sz[x] - 1, k); //cout << "2!!\n"; } signed main() { // freopen("input.in", "r", stdin); // freopen("ans.txt", "w", stdout); scanf("%lld%lld%lld%lld", &n, &m, &r, &mod); int u, v, op, x, y, k; // cout << "1!!\n"; for (int i = 1; i <= n; i++) scanf("%lld", &w[i]); // cout << "2!!\n"; for (int i = 1; i < n; i++) { scanf("%lld%lld", &u, &v); add(u, v); add(v, u); } // cout << "3!!\n"; dfs1(r, 0, 1); // cout << "4!!\n"; dfs2(r, r); //cout << "5!!\n"; build(1, 1, n); //cout << "6!!\n"; while (m--) { scanf("%lld", &op); if (op == 1) { scanf("%lld%lld%lld", &x, &y, &k); uroad(x, y, k); } else if (op == 2) { scanf("%lld%lld", &x, &y); printf("%lld\n", qroad(x, y)); } else if (op == 3) { scanf("%lld%lld", &x, &k); utree(x, k); } else if (op == 4) { scanf("%lld", &x); printf("%lld\n", qtree(x)); } } //cout << "7!!\n"; return 0; } ```
by farfar @ 2022-07-03 14:19:19


`if (depth[x] < depth[y])` 这里应该是 `if (depth[top[x]] < depth[top[y]])`：跳链时要比较两端所在链的链顶深度，让链顶更深的一端往上跳，否则可能越过 LCA（别问我为什么能看出来，亲身经历）
by ningago @ 2022-07-03 14:23:39


@[farfar](/user/378951)
by ningago @ 2022-07-03 14:23:49


@[ningago](/user/371968) 感谢大佬!
by farfar @ 2022-07-03 15:17:37


|