线段树合并 & 分裂(二)
ShiRoZeTsu
·
·
个人记录
有了第一篇,所以就要有第二篇啦
qwq
以下是习题:
- P3899 [湖南集训] 更为厉害
- CF1009F Dominant Indices
- CF570D Tree Requests
- CF246E Blood Cousins Return
\
P3899 [湖南集训] 更为厉害
简要题意:给定一棵有根树,每次给定一组 a, k,求有多少个二元组 (b, c) 满足:
解法
由于 k 会发生变化,所以先预处理再回答询问的思路肯定行不通,因为至少要 \mathcal{O}(n^2)。
考虑把 a 固定,不难发现 b 要么是 a 的祖先,要么是 a 的后代。可以分情况讨论一下:
然后我们就可以统计答案:
-
对于第一种情况,不妨设 a 的深度为 dep_a(根节点深度为 1),那么 a 的上方,合法的 b 的数量共有 \min(dep_a - 1, k) 个。那么这样的 (b, c) 共有 \min(dep_a - 1, k) \times (sze_a - 1) 个。
-
对于第二种情况,不妨设 a 的深度为 dep_a(根节点深度为 1),那么 a 的下方,合法的 b 的深度一定在 dep_a + 1 至 dep_a + k 之间,因此这样的 (b, c) 共有 \sum_{dep_a + 1 \leq dep_b \leq dep_a + k} (sze_b - 1) 个。
所以总的答案为:
\min(dep_a - 1, k) \times (sze_a - 1) + \sum_{dep_a + 1 \leq dep_b \leq dep_a + k} (sze_b - 1)
前面那个东西可以直接算,主要瓶颈在后半部分。
实际上后半部分可以看成一个区间求和,那么自然可以用线段树去维护。
然后在 dfs 的时候合并上来,并把当前点 u 的贡献加进去,也就是在 dep_u 的位置单点加 sze_u - 1。
代码很好写。
CODE:
#include <iostream>
#include <cstdio>
#include <vector>
using namespace std;
typedef long long ll;
const int maxn = 3e5 + 5;
int n, q;
int dep[maxn], sze[maxn], root[maxn];
ll ans[maxn];
struct edge {
int to, nxt;
} e[maxn<<1];
int tot = 1, head[maxn];
void addedge(int u, int v) {
e[++tot].to = v;
e[tot].nxt = head[u];
head[u] = tot;
}
struct node {
int k, id;
};
vector<node> ask[maxn];
#define mid ((l + r) >> 1)
int cnt, top;
int stk[maxn<<5];
int ls[maxn<<2], rs[maxn<<2];
ll sum[maxn<<2];
int newnode() {
if(top) {
int x = stk[top];
top--;
return x;
}
else return ++cnt;
}
void del(int o) {
ls[o] = rs[o] = sum[o] = 0;
stk[++top] = o;
}
void merge(int& o, int p, int l, int r) {
if(!o || !p) { o += p; return; }
if(l == r) { sum[o] += sum[p]; del(p); return; }
merge(ls[o], ls[p], l, mid);
merge(rs[o], rs[p], mid+1, r);
sum[o] = sum[ls[o]] + sum[rs[o]];
del(p);
}
void modify(int& o, int l, int r, int pos, ll val) {
if(!o) o = newnode();
if(l == r) { sum[o] += val; return; }
if(pos <= mid) modify(ls[o], l, mid, pos, val);
else modify(rs[o], mid+1, r, pos, val);
sum[o] = sum[ls[o]] + sum[rs[o]];
}
ll query(int& o, int l, int r, int ql, int qr) {
if(ql <= l && r <= qr) return sum[o];
ll res = 0;
if(ql <= mid) res += query(ls[o], l, mid, ql, qr);
if(mid < qr) res += query(rs[o], mid+1, r, ql, qr);
return res;
}
void dfs(int u, int fa) {
dep[u] = dep[fa] + 1;
sze[u] = 1;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa) continue;
dfs(v, u);
sze[u] += sze[v];
merge(root[u], root[v], 1, n);
}
modify(root[u], 1, n, dep[u], sze[u]-1);
for(int i = 0; i < ask[u].size(); i++) {
int k = ask[u][i].k, id = ask[u][i].id;
ans[id] = 1ll * min(dep[u]-1, k) * (sze[u] - 1);
ans[id] += query(root[u], 1, n, dep[u]+1, min(dep[u]+k, n));
}
}
int main() {
scanf("%d %d", &n, &q);
for(int i = 1; i < n; i++) {
int u, v;
scanf("%d %d", &u, &v);
addedge(u, v); addedge(v, u);
}
for(int i = 1; i <= q; i++) {
int p, k;
scanf("%d %d", &p, &k);
ask[p].push_back((node){k, i});
}
dfs(1, 0);
for(int i = 1; i <= q; i++)
printf("%lld\n", ans[i]);
return 0;
}
\
CF1009F Dominant Indices
简要题意:给定一棵以 1 为根,n 个节点的树。设 d(u,x) 为 u 子树中到 u 距离为 x 的节点数。 对于每个点,求一个最小的 k,使得 d(u,k) 最大。
解法
这道题相对于上一道题来说难度小了不少,还是很板子的。
具体做法就是,用线段树维护子树中,每种深度的点的出现次数,以及出现次数最多的最小深度。
然后 dfs,合并线段树之后用查询的答案减去当前点的深度即可。
不过这道题有一个小亮点:1 \leq n \leq 10^6。
在第一次写完提交之后,我在 #131 RE 了好几次,最后发现空间开小了,但是实际上我已经全部开到了 maxn \times 2^5 空间理应来说已经很大了。
哪里出问题了呢?
我仔细地检查了我的代码,发现了这样一个细节:
void dfs(int u, int fa) {
dep[u] = dep[fa] + 1;
//* build(root[u], 1, n, dep[u]);
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa) continue;
dfs(v, u);
merge(root[u], root[v], 1, n);
}
ans[u] = mn[root[u]] - dep[u];
}
注意打 * 的那一行,每次调用 build 函数都会新建一棵线段树。由于只有一个点,所以会占用 \log n 的空间。
然后我们考虑,当树退化成一条链的时候,会发生什么。
每递归一层就占用 $\log n$ 的空间,当递归到最下端时,已经占用了 $n \times \log n$ 的空间!当 $n = 10^6$ 时,这个数字非常恐怖。
再来看更改之后的写法:
```cpp
void dfs(int u, int fa) {
dep[u] = dep[fa] + 1;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa) continue;
dfs(v, u);
merge(root[u], root[v], 1, n);
}
//* modify(root[u], 1, n, dep[u]);
ans[u] = mn[root[u]] - dep[u];
}
```
这里面的 $modify$ 和 $build$ 作用类似,都是在 $dep_u$ 的位置插入 $1$。
上下两种代码有什么区别?
你应该注意到了,上面的代码是**先新建再合并**,但后面这一种是**先合并,再在原有的树上添加**。后者会将重复的空间利用起来,明显是更优秀的。
所以这也算是一个小要点了:**一定要先合并,后添加**。
其它的部分很好写,不过这个点确实值得积累。
CODE:
```cpp
#include <iostream>
#include <cstdio>
using namespace std;
const int maxn = 1e6 + 5;
int n;
int ans[maxn], dep[maxn], root[maxn];
struct edge {
int to, nxt;
} e[maxn<<1];
int tot = 1, head[maxn];
void addedge(int u, int v) {
e[++tot].to = v;
e[tot].nxt = head[u];
head[u] = tot;
}
#define mid ((l + r) >> 1)
int cnt, top;
int mn[maxn<<2], sum[maxn<<2];
int ls[maxn<<2], rs[maxn<<2], stk[maxn<<5];
int newnode() {
if(top) {
int x = stk[top];
top--;
return x;
}
else return ++cnt;
}
void del(int o) {
ls[o] = rs[o] = mn[o] = sum[o] = 0;
stk[++top] = o;
}
void modify(int& o, int l, int r, int pos) {
if(!o) o = newnode();
if(l == r) { sum[o]++; mn[o] = pos; return; }
if(pos <= mid) modify(ls[o], l, mid, pos);
else modify(rs[o], mid+1, r, pos);
if(sum[ls[o]] >= sum[rs[o]]) mn[o] = mn[ls[o]];
else mn[o] = mn[rs[o]];
sum[o] = max(sum[ls[o]], sum[rs[o]]);
}
void merge(int& o, int p, int l, int r) {
if(!o || !p) { o += p; return; }
if(l == r) { sum[o] += sum[p]; del(p); return; }
merge(ls[o], ls[p], l, mid);
merge(rs[o], rs[p], mid+1, r);
if(sum[ls[o]] >= sum[rs[o]]) mn[o] = mn[ls[o]];
else mn[o] = mn[rs[o]];
sum[o] = max(sum[ls[o]], sum[rs[o]]);
del(p);
}
void dfs(int u, int fa) {
dep[u] = dep[fa] + 1;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa) continue;
dfs(v, u);
merge(root[u], root[v], 1, n);
}
modify(root[u], 1, n, dep[u]);
ans[u] = mn[root[u]] - dep[u];
}
int main() {
scanf("%d", &n);
for(int i = 1; i < n; i++) {
int u, v;
scanf("%d %d", &u, &v);
addedge(u, v); addedge(v, u);
}
dfs(1, 0);
for(int i = 1; i <= n; i++)
printf("%d\n", ans[i]);
return 0;
}
```
$$
\
$$
## [CF570D Tree Requests](https://www.luogu.com.cn/problem/CF570D)
> 简要题意:
给定一个有根树,每个点上有一个小写字母。每次询问 $a, b$ 查询以 $a$ 为根的子树内,深度为 $b$ 的所有结点上的字母,重新排列之后是否能构成回文串。
#### 解法
首先,一个字符串重新排列后是否能成为回文串,只需要看其中出现奇数次的字符是否小于 $2$ 就可以了。
沿用之前的思路,利用线段树合并,把询问离线下来,在 dfs 的过程中求解。
但是有一个问题,如何维护一个深度处,所有字母的出现次数?
要注意,我们只在乎这些字母的**奇偶性**,所以可以用 $1$ 和 $0$ 来表示一个字符是否出现奇数次。
其次,由于字符数量只有 $26$,可以考虑压缩进一个**二进制数**里,从低到高位分别表示字母 `a`,`b`,`c`......是否出现奇数次。
这样数字的值域就是 $[0, 2^{26}-1]$,`int` 就够用了。
然后考虑怎么修改,实际上就是在字母对应位置异或一下就可以了。
至于查询,只要看对应位置的数字的二进制形式中,数字 $1$ 出现次数是否 $\leq 1$ 就可以了。具体来说,要么 $x$ 是 $0$,要么 $x - lowbit(x) = 0$,这两种情况都可以。
CODE:
```cpp
#include <iostream>
#include <cstdio>
#include <vector>
using namespace std;
const int maxn = 5e5 + 5;
int n, m;
int val[maxn], dep[maxn], root[maxn], mxdep[maxn];
char s[maxn];
bool ok[maxn];
struct edge {
int to, nxt;
} e[maxn];
int tot = 1, head[maxn];
void addedge(int u, int v) {
e[++tot].to = v;
e[tot].nxt = head[u];
head[u] = tot;
}
struct node {
int k, id;
};
vector<node> ask[maxn];
#define mid ((l + r) >> 1)
int cnt, top;
int ls[maxn<<2], rs[maxn<<2], sum[maxn<<2], stk[maxn<<5];
int lowbit(int x) { return x & (-x); }
bool check(int x) {
if(!x || x - lowbit(x) == 0) return true;
return false;
}
int newnode() {
if(top) {
int x = stk[top];
top--;
return x;
}
else return ++cnt;
}
void del(int o) {
ls[o] = rs[o] = sum[o] = 0;
stk[++top] = o;
}
void merge(int& o, int p, int l, int r) {
if(!o || !p) { o += p; return; }
if(l == r) { sum[o] ^= sum[p]; del(p); return; }
merge(ls[o], ls[p], l, mid);
merge(rs[o], rs[p], mid+1, r);
del(p);
}
void modify(int& o, int l, int r, int pos, int v) {
if(!o) o = newnode();
if(l == r) { sum[o] ^= v; return; }
if(pos <= mid) modify(ls[o], l, mid, pos, v);
else modify(rs[o], mid+1, r, pos, v);
}
bool query(int o, int l, int r, int pos) {
if(l == r) return check(sum[o]);
if(pos <= mid) return query(ls[o], l, mid, pos);
else return query(rs[o], mid+1, r, pos);
}
void dfs(int u, int fa) {
dep[u] = dep[fa] + 1;
mxdep[u] = dep[u];
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
dfs(v, u);
merge(root[u], root[v], 1, n);
mxdep[u] = max(mxdep[u], mxdep[v]);
}
modify(root[u], 1, n, dep[u], val[u]);
for(int i = 0; i < ask[u].size(); i++) {
int k = ask[u][i].k, id = ask[u][i].id;
ok[id] |= query(root[u], 1, n, k);
}
}
int main() {
scanf("%d %d", &n, &m);
for(int i = 2; i <= n; i++) {
int fa; scanf("%d", &fa);
addedge(fa, i);
}
scanf("%s", s+1);
for(int i = 1; i <= n; i++)
val[i] = (1 << (s[i]-'a'));
for(int i = 1; i <= m; i++) {
int u, k;
scanf("%d %d", &u, &k);
ask[u].push_back((node){k, i});
}
dfs(1, 0);
for(int i = 1; i <= m; i++)
printf(ok[i] ? "Yes\n" : "No\n");
return 0;
}
```
$$
\
$$
## [CF246E Blood Cousins Return](https://www.luogu.com.cn/problem/CF246E)
> 简要题意:给定一个森林,每个点上都有一个字符串。每次询问一个点的 k-son 有多少个不同的字符串。
#### 解法
像前几道题的套路一样,还是对深度开一个线段树,每个位置存储出现的所有字符串。
而字符串可以哈希,或者用 `map`搞。线段树上的叶子节点可以开一个 `set`,查询的时候输出单点的 `set` 的 `size` 即可。
CODE:
```cpp
#include <iostream>
#include <cstdio>
#include <map>
#include <vector>
#include <set>
using namespace std;
const int maxn = 1e5 + 5;
int n, m, stot;
int w[maxn], dep[maxn], root[maxn], f[maxn], mxdep[maxn];
unsigned long long ans[maxn];
map<string, int> mp;
struct edge {
int to, nxt;
} e[maxn];
int tot = 1, head[maxn];
void addedge(int u, int v) {
e[++tot].to = v;
e[tot].nxt = head[u];
head[u] = tot;
}
struct node {
int k, id;
};
vector<node> ask[maxn];
#define mid ((l + r) >> 1)
int cnt, top;
int ls[maxn<<2], rs[maxn<<2], stk[maxn<<5];
set<int> val[maxn<<2];
void mergeset(set<int>& x, set<int> y) {
for(auto it = y.begin(); it != y.end(); it++)
x.insert(*it);
}
int newnode() {
if(top) {
int x = stk[top];
top--;
return x;
}
else return ++cnt;
}
void del(int o) {
ls[o] = rs[o] = 0;
val[o].clear();
stk[++top] = o;
}
void merge(int& o, int p, int l, int r) {
if(!o || !p) { o += p; return; }
if(l == r) { mergeset(val[o], val[p]); del(p); return; }
merge(ls[o], ls[p], l, mid);
merge(rs[o], rs[p], mid+1, r);
del(p);
}
void modify(int& o, int l, int r, int pos, int x) {
if(!o) o = newnode();
if(l == r) { val[o].insert(x); return; }
if(pos <= mid) modify(ls[o], l, mid, pos, x);
else modify(rs[o], mid+1, r, pos, x);
}
unsigned long long query(int o, int l, int r, int pos) {
if(l == r) return val[o].size();
if(pos <= mid) return query(ls[o], l, mid, pos);
else return query(rs[o], mid+1, r, pos);
}
void dfs(int u, int fa) {
dep[u] = dep[fa] + 1;
mxdep[u] = dep[u];
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
dfs(v, u);
merge(root[u], root[v], 1, n);
mxdep[u] = max(mxdep[u], mxdep[v]);
}
modify(root[u], 1, n, dep[u], w[u]);
for(int i = 0; i < ask[u].size(); i++) {
int k = ask[u][i].k, id = ask[u][i].id;
if(dep[u]+k <= mxdep[u])
ans[id] = query(root[u], 1, n, dep[u]+k);
}
}
void deltree(int u) {
if(ls[u]) deltree(ls[u]);
if(rs[u]) deltree(rs[u]);
del(u);
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
cin >> n;
string s;
int fa;
for(int i = 1; i <= n; i++) {
cin >> s;
if(!mp[s]) mp[s] = ++stot;
w[i] = mp[s];
cin >> f[i];
if(f[i]) addedge(f[i], i);
}
cin >> m;
int u, k;
for(int i = 1; i <= m; i++) {
cin >> u >> k;
ask[u].push_back((node){k, i});
}
for(int i = 1; i <= n; i++)
if(!f[i]) {
dfs(i, 0);
deltree(root[i]);
}
for(int i = 1; i <= m; i++)
cout << ans[i] << '\n';
return 0;
}
```
暂时先写到这。
$$
qwq
$$