线段树合并 & 分裂(二)

· · 个人记录

有了第一篇,所以就要有第二篇啦

qwq

以下是习题:

\

P3899 [湖南集训] 更为厉害

简要题意:给定一棵有根树,每次给定一组 a, k,求有多少个二元组 (b, c) 满足:

解法

由于 k 会发生变化,所以先预处理再回答询问的思路肯定行不通,因为至少要 \mathcal{O}(n^2)

考虑把 a 固定,不难发现 b 要么是 a祖先,要么是 a后代。可以分情况讨论一下:

然后我们就可以统计答案:

所以总的答案为:

\min(dep_a - 1, k) \times (sze_a - 1) + \sum_{dep_a + 1 \leq dep_b \leq dep_a + k} (sze_b - 1)

前面那个东西可以直接算,主要瓶颈在后半部分。

实际上后半部分可以看成一个区间求和,那么自然可以用线段树去维护。

然后在 dfs 的时候合并上来,并把当前点 u 的贡献加进去,也就是在 dep_u 的位置单点加 sze_u - 1

代码很好写。

CODE:

#include <iostream>
#include <cstdio>
#include <vector>
using namespace std;
typedef long long ll;

const int maxn = 3e5 + 5;

int n, q;
int dep[maxn], sze[maxn], root[maxn];
ll ans[maxn];

struct edge {
    int to, nxt;
} e[maxn<<1];

int tot = 1, head[maxn];
void addedge(int u, int v) {
    e[++tot].to = v;
    e[tot].nxt = head[u];
    head[u] = tot;
}

struct node {
    int k, id;
};
vector<node> ask[maxn];

#define mid ((l + r) >> 1)

int cnt, top;
int stk[maxn<<5];
int ls[maxn<<2], rs[maxn<<2];
ll sum[maxn<<2];

int newnode() {
    if(top) {
        int x = stk[top];
        top--;
        return x;
    }
    else return ++cnt;
}

void del(int o) {
    ls[o] = rs[o] = sum[o] = 0;
    stk[++top] = o;
}

void merge(int& o, int p, int l, int r) {
    if(!o || !p) { o += p; return; }
    if(l == r) { sum[o] += sum[p]; del(p); return; }
    merge(ls[o], ls[p], l, mid);
    merge(rs[o], rs[p], mid+1, r);
    sum[o] = sum[ls[o]] + sum[rs[o]];
    del(p);
}

void modify(int& o, int l, int r, int pos, ll val) {
    if(!o) o = newnode();
    if(l == r) { sum[o] += val; return; }
    if(pos <= mid) modify(ls[o], l, mid, pos, val);
    else modify(rs[o], mid+1, r, pos, val);
    sum[o] = sum[ls[o]] + sum[rs[o]];
}

ll query(int& o, int l, int r, int ql, int qr) {
    if(ql <= l && r <= qr) return sum[o];
    ll res = 0;
    if(ql <= mid) res += query(ls[o], l, mid, ql, qr);
    if(mid < qr) res += query(rs[o], mid+1, r, ql, qr);
    return res;
}

void dfs(int u, int fa) {
    dep[u] = dep[fa] + 1;
    sze[u] = 1;
    for(int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        if(v == fa) continue;
        dfs(v, u);
        sze[u] += sze[v];
        merge(root[u], root[v], 1, n);
    }
    modify(root[u], 1, n, dep[u], sze[u]-1);
    for(int i = 0; i < ask[u].size(); i++) {
        int k = ask[u][i].k, id = ask[u][i].id;
        ans[id] = 1ll * min(dep[u]-1, k) * (sze[u] - 1);
        ans[id] += query(root[u], 1, n, dep[u]+1, min(dep[u]+k, n));
    }
}

int main() {
    scanf("%d %d", &n, &q);
    for(int i = 1; i < n; i++) {
        int u, v;
        scanf("%d %d", &u, &v);
        addedge(u, v); addedge(v, u);
    }
    for(int i = 1; i <= q; i++) {
        int p, k;
        scanf("%d %d", &p, &k);
        ask[p].push_back((node){k, i});
    }
    dfs(1, 0);

    for(int i = 1; i <= q; i++)
        printf("%lld\n", ans[i]);
    return 0;
}
\

CF1009F Dominant Indices

简要题意:给定一棵以 1 为根,n 个节点的树。设 d(u,x)u 子树中到 u 距离为 x 的节点数。 对于每个点,求一个最小的 k,使得 d(u,k) 最大。

解法

这道题相对于上一道题来说难度小了不少,还是很板子的。

具体做法就是,用线段树维护子树中,每种深度的点的出现次数,以及出现次数最多的最小深度。

然后 dfs,合并线段树之后用查询的答案减去当前点的深度即可。

不过这道题有一个小亮点:1 \leq n \leq 10^6

在第一次写完提交之后,我在 #131 RE 了好几次,最后发现空间开小了,但是实际上我已经全部开到了 maxn \times 2^5 空间理应来说已经很大了。

哪里出问题了呢?

我仔细地检查了我的代码,发现了这样一个细节:

void dfs(int u, int fa) {
    dep[u] = dep[fa] + 1;
//* build(root[u], 1, n, dep[u]);
    for(int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        if(v == fa) continue;
        dfs(v, u);
        merge(root[u], root[v], 1, n);
    }
    ans[u] = mn[root[u]] - dep[u];
}

注意打 * 的那一行,每次调用 build 函数都会新建一棵线段树。由于只有一个点,所以会占用 \log n 的空间。

然后我们考虑,当树退化成一条链的时候,会发生什么。

每递归一层就占用 $\log n$ 的空间,当递归到最下端时,已经占用了 $n \times \log n$ 的空间!当 $n = 10^6$ 时,这个数字非常恐怖。 再来看更改之后的写法: ```cpp void dfs(int u, int fa) { dep[u] = dep[fa] + 1; for(int i = head[u]; i; i = e[i].nxt) { int v = e[i].to; if(v == fa) continue; dfs(v, u); merge(root[u], root[v], 1, n); } //* modify(root[u], 1, n, dep[u]); ans[u] = mn[root[u]] - dep[u]; } ``` 这里面的 $modify$ 和 $build$ 作用类似,都是在 $dep_u$ 的位置插入 $1$。 上下两种代码有什么区别? 你应该注意到了,上面的代码是**先新建再合并**,但后面这一种是**先合并,再在原有的树上添加**。后者会将重复的空间利用起来,明显是更优秀的。 所以这也算是一个小要点了:**一定要先合并,后添加**。 其它的部分很好写,不过这个点确实值得积累。 CODE: ```cpp #include <iostream> #include <cstdio> using namespace std; const int maxn = 1e6 + 5; int n; int ans[maxn], dep[maxn], root[maxn]; struct edge { int to, nxt; } e[maxn<<1]; int tot = 1, head[maxn]; void addedge(int u, int v) { e[++tot].to = v; e[tot].nxt = head[u]; head[u] = tot; } #define mid ((l + r) >> 1) int cnt, top; int mn[maxn<<2], sum[maxn<<2]; int ls[maxn<<2], rs[maxn<<2], stk[maxn<<5]; int newnode() { if(top) { int x = stk[top]; top--; return x; } else return ++cnt; } void del(int o) { ls[o] = rs[o] = mn[o] = sum[o] = 0; stk[++top] = o; } void modify(int& o, int l, int r, int pos) { if(!o) o = newnode(); if(l == r) { sum[o]++; mn[o] = pos; return; } if(pos <= mid) modify(ls[o], l, mid, pos); else modify(rs[o], mid+1, r, pos); if(sum[ls[o]] >= sum[rs[o]]) mn[o] = mn[ls[o]]; else mn[o] = mn[rs[o]]; sum[o] = max(sum[ls[o]], sum[rs[o]]); } void merge(int& o, int p, int l, int r) { if(!o || !p) { o += p; return; } if(l == r) { sum[o] += sum[p]; del(p); return; } merge(ls[o], ls[p], l, mid); merge(rs[o], rs[p], mid+1, r); if(sum[ls[o]] >= sum[rs[o]]) mn[o] = mn[ls[o]]; else mn[o] = mn[rs[o]]; sum[o] = max(sum[ls[o]], sum[rs[o]]); del(p); } void dfs(int u, int fa) { dep[u] = dep[fa] + 1; for(int i = head[u]; i; i = e[i].nxt) { int v = e[i].to; if(v == fa) continue; dfs(v, u); merge(root[u], root[v], 1, n); } modify(root[u], 1, n, dep[u]); ans[u] = mn[root[u]] - dep[u]; } int main() { scanf("%d", &n); for(int i = 1; i < n; i++) { int u, v; scanf("%d %d", &u, &v); addedge(u, v); addedge(v, u); } dfs(1, 0); for(int i = 1; i <= n; i++) printf("%d\n", ans[i]); return 0; } ``` $$ \ $$ ## [CF570D Tree Requests](https://www.luogu.com.cn/problem/CF570D) > 简要题意: 给定一个有根树,每个点上有一个小写字母。每次询问 $a, b$ 查询以 $a$ 为根的子树内,深度为 $b$ 的所有结点上的字母,重新排列之后是否能构成回文串。 #### 解法 首先,一个字符串重新排列后是否能成为回文串,只需要看其中出现奇数次的字符是否小于 $2$ 就可以了。 沿用之前的思路,利用线段树合并,把询问离线下来,在 dfs 的过程中求解。 但是有一个问题,如何维护一个深度处,所有字母的出现次数? 要注意,我们只在乎这些字母的**奇偶性**,所以可以用 $1$ 和 $0$ 来表示一个字符是否出现奇数次。 其次,由于字符数量只有 $26$,可以考虑压缩进一个**二进制数**里,从低到高位分别表示字母 `a`,`b`,`c`......是否出现奇数次。 这样数字的值域就是 $[0, 2^{26}-1]$,`int` 就够用了。 然后考虑怎么修改,实际上就是在字母对应位置异或一下就可以了。 至于查询,只要看对应位置的数字的二进制形式中,数字 $1$ 出现次数是否 $\leq 1$ 就可以了。具体来说,要么 $x$ 是 $0$,要么 $x - lowbit(x) = 0$,这两种情况都可以。 CODE: ```cpp #include <iostream> #include <cstdio> #include <vector> using namespace std; const int maxn = 5e5 + 5; int n, m; int val[maxn], dep[maxn], root[maxn], mxdep[maxn]; char s[maxn]; bool ok[maxn]; struct edge { int to, nxt; } e[maxn]; int tot = 1, head[maxn]; void addedge(int u, int v) { e[++tot].to = v; e[tot].nxt = head[u]; head[u] = tot; } struct node { int k, id; }; vector<node> ask[maxn]; #define mid ((l + r) >> 1) int cnt, top; int ls[maxn<<2], rs[maxn<<2], sum[maxn<<2], stk[maxn<<5]; int lowbit(int x) { return x & (-x); } bool check(int x) { if(!x || x - lowbit(x) == 0) return true; return false; } int newnode() { if(top) { int x = stk[top]; top--; return x; } else return ++cnt; } void del(int o) { ls[o] = rs[o] = sum[o] = 0; stk[++top] = o; } void merge(int& o, int p, int l, int r) { if(!o || !p) { o += p; return; } if(l == r) { sum[o] ^= sum[p]; del(p); return; } merge(ls[o], ls[p], l, mid); merge(rs[o], rs[p], mid+1, r); del(p); } void modify(int& o, int l, int r, int pos, int v) { if(!o) o = newnode(); if(l == r) { sum[o] ^= v; return; } if(pos <= mid) modify(ls[o], l, mid, pos, v); else modify(rs[o], mid+1, r, pos, v); } bool query(int o, int l, int r, int pos) { if(l == r) return check(sum[o]); if(pos <= mid) return query(ls[o], l, mid, pos); else return query(rs[o], mid+1, r, pos); } void dfs(int u, int fa) { dep[u] = dep[fa] + 1; mxdep[u] = dep[u]; for(int i = head[u]; i; i = e[i].nxt) { int v = e[i].to; dfs(v, u); merge(root[u], root[v], 1, n); mxdep[u] = max(mxdep[u], mxdep[v]); } modify(root[u], 1, n, dep[u], val[u]); for(int i = 0; i < ask[u].size(); i++) { int k = ask[u][i].k, id = ask[u][i].id; ok[id] |= query(root[u], 1, n, k); } } int main() { scanf("%d %d", &n, &m); for(int i = 2; i <= n; i++) { int fa; scanf("%d", &fa); addedge(fa, i); } scanf("%s", s+1); for(int i = 1; i <= n; i++) val[i] = (1 << (s[i]-'a')); for(int i = 1; i <= m; i++) { int u, k; scanf("%d %d", &u, &k); ask[u].push_back((node){k, i}); } dfs(1, 0); for(int i = 1; i <= m; i++) printf(ok[i] ? "Yes\n" : "No\n"); return 0; } ``` $$ \ $$ ## [CF246E Blood Cousins Return](https://www.luogu.com.cn/problem/CF246E) > 简要题意:给定一个森林,每个点上都有一个字符串。每次询问一个点的 k-son 有多少个不同的字符串。 #### 解法 像前几道题的套路一样,还是对深度开一个线段树,每个位置存储出现的所有字符串。 而字符串可以哈希,或者用 `map`搞。线段树上的叶子节点可以开一个 `set`,查询的时候输出单点的 `set` 的 `size` 即可。 CODE: ```cpp #include <iostream> #include <cstdio> #include <map> #include <vector> #include <set> using namespace std; const int maxn = 1e5 + 5; int n, m, stot; int w[maxn], dep[maxn], root[maxn], f[maxn], mxdep[maxn]; unsigned long long ans[maxn]; map<string, int> mp; struct edge { int to, nxt; } e[maxn]; int tot = 1, head[maxn]; void addedge(int u, int v) { e[++tot].to = v; e[tot].nxt = head[u]; head[u] = tot; } struct node { int k, id; }; vector<node> ask[maxn]; #define mid ((l + r) >> 1) int cnt, top; int ls[maxn<<2], rs[maxn<<2], stk[maxn<<5]; set<int> val[maxn<<2]; void mergeset(set<int>& x, set<int> y) { for(auto it = y.begin(); it != y.end(); it++) x.insert(*it); } int newnode() { if(top) { int x = stk[top]; top--; return x; } else return ++cnt; } void del(int o) { ls[o] = rs[o] = 0; val[o].clear(); stk[++top] = o; } void merge(int& o, int p, int l, int r) { if(!o || !p) { o += p; return; } if(l == r) { mergeset(val[o], val[p]); del(p); return; } merge(ls[o], ls[p], l, mid); merge(rs[o], rs[p], mid+1, r); del(p); } void modify(int& o, int l, int r, int pos, int x) { if(!o) o = newnode(); if(l == r) { val[o].insert(x); return; } if(pos <= mid) modify(ls[o], l, mid, pos, x); else modify(rs[o], mid+1, r, pos, x); } unsigned long long query(int o, int l, int r, int pos) { if(l == r) return val[o].size(); if(pos <= mid) return query(ls[o], l, mid, pos); else return query(rs[o], mid+1, r, pos); } void dfs(int u, int fa) { dep[u] = dep[fa] + 1; mxdep[u] = dep[u]; for(int i = head[u]; i; i = e[i].nxt) { int v = e[i].to; dfs(v, u); merge(root[u], root[v], 1, n); mxdep[u] = max(mxdep[u], mxdep[v]); } modify(root[u], 1, n, dep[u], w[u]); for(int i = 0; i < ask[u].size(); i++) { int k = ask[u][i].k, id = ask[u][i].id; if(dep[u]+k <= mxdep[u]) ans[id] = query(root[u], 1, n, dep[u]+k); } } void deltree(int u) { if(ls[u]) deltree(ls[u]); if(rs[u]) deltree(rs[u]); del(u); } int main() { ios::sync_with_stdio(false); cin.tie(nullptr); cout.tie(nullptr); cin >> n; string s; int fa; for(int i = 1; i <= n; i++) { cin >> s; if(!mp[s]) mp[s] = ++stot; w[i] = mp[s]; cin >> f[i]; if(f[i]) addedge(f[i], i); } cin >> m; int u, k; for(int i = 1; i <= m; i++) { cin >> u >> k; ask[u].push_back((node){k, i}); } for(int i = 1; i <= n; i++) if(!f[i]) { dfs(i, 0); deltree(root[i]); } for(int i = 1; i <= m; i++) cout << ans[i] << '\n'; return 0; } ``` 暂时先写到这。 $$ qwq $$