树链剖分 学习报告

· · 题解

本题作为最终例题,代码在最下方给出。

众所周知,树链剖分是一个非常强的OIer(大雾

一些定义

重儿子:子树最大的儿子节点

轻儿子:其他儿子节点

重链:由重儿子组成的链

轻链:其他链

原理及实现

剖分结果

首先这是一颗

先不谈原理,直接将剖分后的树拿出来

其中加重的边构成重链,重儿子被染成红色,轻儿子被染成蓝色

注意1112号节点,他们的子树大小相同,此时我们仅将节点编号的作为重儿子。

对着上面的两幅图可以更好的理解“重”的概念。

剖分过程

我们通过两次dfs进行剖分。

规定变量

/*默认邻接表存边*/
sz[x] // 以x节点为根的子树大小
d[x] // x节点的深度
fa[x] // x节点的父节点
son[x] // x节点的重儿子(注意是重儿子)
tp[x] // x节点所在链的链顶

开始剖分

第一遍dfs处理出每个节点的深度和子树大小,从而处理出每个节点的重儿子。

void dfs1(int x) {
    int maxx = -1; // 记录重儿子的子树大小以更新
    for(register int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if(y == fa[x]) continue;
        d[y] = d[x] + 1;
        fa[y] = x;
        sz[x]++;
        dfs1(y);
        sz[x] += sz[y];
        if(sz[y] > maxx) {
            maxx = sz[y];
            son[x] = y;
        }
    }
}

第二遍dfs处理出每个节点所在链的链顶,注意这里的链仅指重链,规定轻链上每个点的链顶都为自己。

void dfs2(int x, int top) {
    tp[x] = top;
    //这里要优先遍历重儿子
    if(son[x]) dfs2(son[x], top);
    for(register int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if(y == son[x] || y == fa[x]) continue;
        dfs2(y, y);
    }
}

至此,我们已经完成了树链剖分。

处理问题

那么,树剖可以处理哪些问题呢?

在线LCA

考虑每条重链,根据定义,我们发现每个点若链顶相同,则LCA为深度较小的点。

可以结合上面的图理解一下。

有了这个性质,我们不妨尝试将每个点都转移到链顶相同的点,于是就有了下面的代码。

inline int lca(int x, int y) {
    int tp1 = tp[x], tp2 = tp[y];
    while(tp1 != tp2) {
        if(d[tp1] > d[tp2]) {
            std::swap(tp1, tp2);
            std::swap(x, y);
        }
        y = fa[tp2];
        tp2 = tp[y];
    }
    return d[x] < d[y] ? x : y;
}

复杂度为O(\log n)

例题:P3379 【模板】最近公共祖先(LCA)

板子题,直接套上面的代码。

#include <cstdio>
#include <cctype>

namespace FastIO {
    inline int read() {
        char ch = getchar(); int r = 0, w = 1;
        while(!isdigit(ch)) {if(ch == '-') w = -1; ch = getchar();}
        while(isdigit(ch)) {r = r * 10 + ch - '0', ch = getchar();}
        return r * w;
    }
    void _write(int x) {
        if(x < 0) putchar('-'), x = -x;
        if(x > 9) _write(x / 10);
        putchar(x % 10 + '0');
    }
    inline void write(int x) {
        _write(x);
        puts("");
    }
}

using namespace FastIO;

template <typename T> inline void swap(T &x, T &y) {T tmp = x; x = y, y = tmp;}

const int N = 5e5 + 6;
const int M = N << 1;

int head[N], nxt[M], ver[M], cnt;
int tp[N], sz[N], fa[N], son[N];
int d[N];

inline void add(int x, int y) {
    ver[++cnt] = y, nxt[cnt] = head[x], head[x] = cnt;
}

void dfs1(int x) {
    int maxx = -1;
    for(register int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if(y == fa[x]) continue;
        d[y] = d[x] + 1;
        sz[x]++;
        fa[y] = x;
        dfs1(y);
        sz[x] += sz[y];
        if(sz[y] > maxx) {
            maxx = sz[y];
            son[x] = y;
        }
    }
}

void dfs2(int x, int top) {
    tp[x] = top;
    if(son[x]) dfs2(son[x], top);
    for(register int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if(y == son[x] || y == fa[x]) continue;
        dfs2(y, y);
    }
}

inline int lca(int x, int y) {
    int tp1 = tp[x], tp2 = tp[y];
    while(tp1 != tp2) {
        if(d[tp1] > d[tp2]) {
            swap(tp1, tp2);
            swap(x, y);
        }
        y = fa[tp2];
        tp2 = tp[y];
    }
    return d[x] < d[y] ? x : y;
}

int main() {
    int n = read(), m = read(), s = read();
    for(register int i = 1; i < n; i++) {
        int x = read(), y = read();
        add(x, y);
        add(y, x);
    }
    dfs1(s);
    dfs2(s, s);
    for(register int i = 1; i <= m; i++) {
        write(lca(read(), read()));
    }
    return 0;
}

处理子树

考虑一棵树的dfs

dfs序

用数组dfn[x]表示节点x的dfs序,即对根节点dfs时访问到它的顺序。

这个可以在dfs2()时维护:

void dfs2(int x, int top) {
    tp[x] = top;
    dfn[x] = ++idx;
    if(son[x]) dfs2(son[x], top);
    for(register int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if(y == son[x] || y == fa[x]) continue;
        dfs2(y, y);
    }
}

依然是剖分后的树,我们标记一下每个节点的dfs序:

可以发现,每个子树中的节点的dfs序是连续的。

我们用区间来表示一颗子树:

清楚多了。

可以发现每个节点区间的左端点为dfn[x],右端点为dfn[x] + sz[x] - 1

这让我们联想到线段树

于是我们可以对整棵树建立一颗线段树,这样可以区间查询获得子树信息

例题:P3384 【模板】树链剖分

链的dfs序显然是连续的,同样放在线段树上就行了。

注意本题中子树大小将根节点计算入内,sz[x] 应初始化为1。

#include <cstdio>
#include <cctype>
#define LL long long

namespace FastIO {
    inline int read() {
        char ch = getchar(); int r = 0, w = 1;
        while(!isdigit(ch)) {if(ch == '-') w = -1; ch = getchar();}
        while(isdigit(ch)) {r = r * 10 + ch - '0', ch = getchar();}
        return r * w;
    }
    void _write(int x) {
        if(x < 0) putchar('-'), x = -x;
        if(x > 9) _write(x / 10);
        putchar(x % 10 + '0');
    }
    inline void write(int x) {
        _write(x);
        puts("");
    }
}

using namespace FastIO;

template <typename T> inline void swap(T &x, T &y) {T tmp = x; x = y, y = tmp;}

const int N = 1e5 + 6;
const int M = N << 1;

int n, m, r, p, ptn[N], w[N];

// 线段树

#define lson (o << 1)
#define rson (o << 1) | 1
#define mid ((l + r) >> 1)

LL val[N << 2], tag[N << 2];

inline void pushdown(int o, int l, int r) {
    if(tag[o]) {
        tag[lson] = (tag[lson] + tag[o]) % p;
        tag[rson] = (tag[rson] + tag[o]) % p;
        val[lson] = (val[lson] + tag[o] * (mid - l + 1) % p) % p;
        val[rson] = (val[rson] + tag[o] * (r - mid) % p) % p;
        tag[o] = 0;
    }
}

inline void upd(int o) {
    val[o] = (val[lson] + val[rson]) % p;
}

inline void build(int o, int l, int r) {
    if(l == r) {
        val[o] = w[ptn[l]];
        return;
    }
    build(lson, l, mid);
    build(rson, mid + 1, r);
    upd(o);
}

inline void add(int o, int l, int r, int ll, int rr, LL v) {
    if(l >= ll && r <= rr) {
        val[o] = (val[o] + v * (r - l + 1) % p) % p;
        tag[o] = (tag[o] + v) % p;
        return;
    }
    pushdown(o, l, r);
    if(ll <= mid) add(lson, l, mid, ll, rr, v);
    if(rr > mid) add(rson, mid + 1, r, ll, rr, v);
    upd(o);
}

inline LL squery(int o, int l, int r, int ll, int rr) {
    if(l >= ll && r <= rr) {
        return val[o];
    }
    pushdown(o, l, r);
    LL res = 0;
    if(ll <= mid) res = (res + squery(lson, l, mid, ll, rr)) % p;
    if(rr > mid) res = (res + squery(rson, mid + 1, r, ll, rr)) % p;
    return res;
}

// 树剖
int head[N], nxt[M], ver[M], cnt;
int fa[N], sz[N], d[N], tp[N], son[N], dfn[N], idx;

inline void adde(int x, int y) {
    ver[++cnt] = y, nxt[cnt] = head[x], head[x] = cnt;
}

void dfs1(int x) {
    sz[x] = 1;
    int maxx = -1;
    for(register int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if(y == fa[x]) continue;
        d[y] = d[x] + 1;
        fa[y] = x;
        dfs1(y);
        sz[x] += sz[y];
        if(sz[y] > maxx) {
            maxx = sz[y];
            son[x] = y;
        }
    }
}

void dfs2(int x, int top) {
    dfn[x] = ++idx;
    ptn[dfn[x]] = x;
    tp[x] = top;
    if(son[x]) dfs2(son[x], top);
    for(register int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if(y == son[x] || y == fa[x]) continue;
        dfs2(y, y);
    }
}

inline void Addline(int x, int y, LL z) {
    int tx = tp[x], ty = tp[y];
    while(tx != ty) {
        if(d[tx] > d[ty]) {
            swap(tx, ty);
            swap(x, y);
        }
        add(1, 1, n, dfn[ty], dfn[y], z);
        y = fa[ty];
        ty = tp[y];
    }
    if(d[x] > d[y]) swap(x, y);
    add(1, 1, n, dfn[x], dfn[y], z);
}

inline LL Queryline(int x, int y) {
    int tx = tp[x], ty = tp[y];
    LL res = 0;
    while(tx != ty) {
        if(d[tx] > d[ty]) {
            swap(tx, ty);
            swap(x, y);
        }
        res = (res + squery(1, 1, n, dfn[ty], dfn[y])) % p;
        y = fa[ty];
        ty = tp[y];
    }
    if(d[x] > d[y]) swap(x, y);
    res = (res + squery(1, 1, n, dfn[x], dfn[y])) % p;
    return res;
}

inline void Addt(int x, LL z) {
    add(1, 1, n, dfn[x], dfn[x] + sz[x] - 1, z);
}

inline LL Queryt(int x) {
    return squery(1, 1, n, dfn[x], dfn[x] + sz[x] - 1);
}

int op, x, y, z;

int main() {
    n = read(), m = read(), r = read(), p = read();
    for(register int i = 1; i <= n; i++) w[i] = read();
    for(register int i = 1; i < n; i++) {
        int x = read(), y = read();
        adde(x, y);
        adde(y, x);
    }
    dfs1(r);
    dfs2(r, r);
    build(1, 1, n);
    while(m--) {
        op = read();
        x = read();
        if(op <= 2) y = read();
        if(op == 1 || op == 3) z = read();
        switch (op) {
            case 1:
                Addline(x, y, z);
                break;
            case 2:
                write(Queryline(x, y));
                break;
            case 3:
                Addt(x, z);
                break;
            case 4:
                write(Queryt(x));
                break;
        }
    }
    return 0;
}