题解:P12511 [集训队互测 2024] 树上简单求和
yishanyi
·
·
题解
P12511 [集训队互测 2024] 树上简单求和
pro.
给定两棵 n 个点的有编号无根树(形态不保证相同),点的编号从 1 到 n,点 i 的点权为 a_i,两棵树共用点权。
有 m 次操作,每次操作给定 x,y,k,进行两步:
先对第一棵树上 x 到 y 的简单路径上的所有点的权值增加 k ;
再求出第二棵树上 x 到 y 的简单路径上的所有点的权值和,对 2^{64} 取模。
## sol.
> 首先如果树是一条链,那么就变成了二维平面 $x$ 区间加, $y$ 区间和,可以规约矩阵乘法。所以本题显然不存在 $\mathrm{poly} \log$ 的做法。
~~但我不会规约矩乘。~~
首先通过树剖可以把树链操作转化为 $\mathcal{O}(\log n)$ 个 `dfn` 上的区间问题。
假设我们已经实现了一个数据结构,支持在 $\mathcal{O}(x)$ 的复杂度内对第二个序列单点修改,在 $\mathcal{O}(y)$ 的复杂度内对第二个序列区间查询。
考虑如何修改:对第一个序列分块,散块直接暴力在第二个序列上单点修改,查询时通过第二个序列区间查询,共 $\mathcal{O}(m\sqrt n)$ 次修改, $\mathcal{O}(m)$ 次查询,故散块修改复杂度为 $\mathcal{O}(mx\sqrt n)$,查询复杂度为 $\mathcal{O}(my)$。为了根号平衡,我们发现,第二个序列上的数据结构应该支持 $\mathcal{O}(1)$ 单点修改, $\mathcal{O}(\sqrt n)$ 区间查询——还是分块。
整块修改依然考虑延迟标记,问题在于如何在查询时加入影响。
可以前缀和,对于第一个序列的每个块,我们 $\mathcal{O}(n)$ 预处理出第二个序列的每个前缀有多少个点落在该块内,区间查询时该块的对答案的影响为 $tag_i\cdot (pre_{i,r}-pre_{i,l-1})$。整块修改和查询都有 $\mathcal{O}(m\sqrt n)$ 次,单次 $\mathcal{O}(1)$,故整块修改和查询复杂度均为 $\mathcal{O}(m\sqrt n)$。
于是算上树剖就可以在 $\mathcal{O}(m\sqrt n\log n)$ 的时间内解决。
~~然后就被卡空间了(~~
但空间复杂度 $\mathcal{O}(n\sqrt n)\approx9\times10^7$,大概在 $\mathrm{350MB}$ 左右。
怎么优化空间呢?发现如果不维护序列的每个前缀,而仅维护整块的每个前缀,那么空间就变成了 $\mathcal{O}(n)$ 的。
于是我们可以在 $\mathcal{O}(\sqrt n)$ 的时间内解决整块对整块的影响,还需处理整块对散块的影响。
我们惊喜的发现:对于每个点有且仅有一个块将其包含(不知道这个很显然的东西我为什么想了很久才想到),于是对每个散块的影响容易 $\mathcal{O}(1)$ 查询。
然后做完了。时间 $\mathcal{O}(m\sqrt n\log n)$,空间 $\mathcal{O}(n)$。
实测块长 $200$ 左右比较快。
##### cod.
```cpp
constexpr int N = 2e5 + 1, B = 501;
int n, m, b, cnt, pre[B][B], belong[N];
ull a[N], tag[B], sum[B];
struct Tree {
std::vector<int> e[N];
int idx, fa[N], dep[N], siz[N], son[N], top[N], num[N], dfn[N];
void dfs(int u) {
dep[u] = dep[fa[u]] + 1, siz[u] = 1;
for (const int& v : e[u])
if (v != fa[u]) {
fa[v] = u, dfs(v), siz[u] += siz[v];
if (siz[v] > siz[son[u]]) son[u] = v;
}
}
void dfs(int u, int tp) {
top[u] = tp, dfn[num[u] = ++idx] = u;
if (son[u]) dfs(son[u], tp);
for (const int& v : e[u])
if (v != fa[u] && v != son[u])
dfs(v, v);
}
void init() {
for (int i = 1; i < n; i++) {
int u, v;
std::cin >> u >> v;
e[u].push_back(v), e[v].push_back(u);
}
dfs(1), dfs(1, 1);
}
} T1, T2;
void add(int l, int r, ull x) {
int L = belong[l], R = belong[r];
if (L == R) {
for (int i = l, j; i <= r; i++) {
j = T2.num[T1.dfn[i]];
a[T2.dfn[j]] += x, sum[belong[j]] += x;
}
return;
}
for (int i = l, j; belong[i] == L; i++) {
j = T2.num[T1.dfn[i]];
a[T2.dfn[j]] += x, sum[belong[j]] += x;
}
for (int i = L + 1; i < R; i++) tag[i] += x;
for (int i = r, j; belong[i] == R; i--) {
j = T2.num[T1.dfn[i]];
a[T2.dfn[j]] += x, sum[belong[j]] += x;
}
}
void Add(int u, int v, ull k) {
int x = u, y = v;
int *dep = T1.dep, *top = T1.top, *fa = T1.fa, *num = T1.num;
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]) std::swap(u, v);
add(num[top[u]], num[u], k), u = fa[top[u]];
}
if (dep[u] > dep[v]) std::swap(u, v);
add(num[u], num[v], k);
int l = dep[u] < dep[v] ? u : v;
}
ull query(int l, int r) {
ull res = 0;
int L = belong[l], R = belong[r];
if (L == R) {
for (int i = l; i <= r; i++) res += a[T2.dfn[i]] + tag[belong[T1.num[T2.dfn[i]]]];
return res;
}
for (int i = l; belong[i] == L; i++) res += a[T2.dfn[i]] + tag[belong[T1.num[T2.dfn[i]]]];
for (int i = L + 1; i < R; i++) res += sum[i];
if (R > L + 1) for (int i = 1; i <= cnt; i++) res += tag[i] * (pre[i][R - 1] - pre[i][L]);
for (int i = r; belong[i] == R; i--) res += a[T2.dfn[i]] + tag[belong[T1.num[T2.dfn[i]]]];
return res;
}
ull Query(int u, int v) {
ull res = 0;
int *dep = T2.dep, *top = T2.top, *fa = T2.fa, *num = T2.num;
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]) std::swap(u, v);
res += query(num[top[u]], num[u]), u = fa[top[u]];
}
if (dep[u] > dep[v]) std::swap(u, v);
return res + query(num[u], num[v]);
}
void solve() {
std::cin >> n >> m, b = map(int, sqrt(n));
for (int i = 1; i <= n; i++) std::cin >> a[i], belong[i] = (i - 1) / b + 1;
T1.init(), T2.init(), cnt = belong[n];
for (int i = 1; i <= n; i++) sum[belong[i]] += a[T2.dfn[i]], pre[belong[T1.num[i]]][belong[T2.num[i]]]++;
for (int i = 1; i <= cnt; i++) for (int j = 1; j <= cnt; j++) pre[i][j] += pre[i][j - 1];
while (m--) {
int x, y;
ull k;
std::cin >> x >> y >> k;
Add(x, y, k), std::cout << Query(x, y) << "\n";
}
}
```