树链剖分

· · 个人记录

树链剖分——毒瘤数据结构

众所周知, 树链剖分这玩意很毒瘤, 同时可以证明出题人很毒瘤

150+ 行代码敲死我

#include<bits/stdc++.h>
using namespace std;
// Loop-index shorthand. The `register` storage class was removed in C++17
// (ill-formed there), so this now expands to a plain int.
#define ri int
// Fill a whole fixed-size array `a` with byte value `b` (a must be an array, not a pointer).
#define mem(a,b) memset(a, b, sizeof(a))
// 64-bit integer alias. The former `#define long long ll;` redefined the
// `long` keyword itself and broke any later `long long` declaration.
typedef long long ll;
// Segment-tree helpers; they assume variables named rt, l and r are in scope.
#define mid ((l + r) >> 1)
#define lson rt << 1 , l, mid
#define rson rt << 1 | 1 , mid + 1, r
#define len (r - l + 1)
// Fast reader for one decimal integer from stdin.
// Skips every non-digit character first; a '-' among the skipped characters
// makes the result negative. The character right after the number is consumed.
// If EOF is hit before any digit, x is left as 0 instead of looping forever.
template<class T>
inline void read(T &x){
    bool neg = false;
    int ch = getchar();
    x = 0;
    while (ch != EOF && !isdigit(ch)){
        if (ch == '-') neg = true;
        ch = getchar();
    }
    while (ch != EOF && isdigit(ch)){
        x = x * 10 + (ch - '0');   // parenthesised so the digit value is added, not ch then '0' subtracted
        ch = getchar();
    }
    if (neg) x = -x;
}
// Capacity shared by the vertex arrays, the edge pool and the segment tree.
const int MAXN = 2000009;

// One directed edge of the chain-forward-star adjacency lists.
struct Node
{
    int nxt;   // index of the next edge leaving the same vertex (0 ends the list)
    int to;    // target vertex of this edge
} e[MAXN];

// Input parameters: vertex count, operation count, root vertex, modulus.
int n, m, r, mod;
int cnt;             // number of edges stored in e[] so far
int head[MAXN];      // first edge index of each vertex (0 = none)

// Segment tree over DFS timestamps: range sums and pending add-tags.
int a[MAXN << 2];
int lazy[MAXN << 2];

// Heavy-light decomposition data, indexed by vertex (wt by timestamp).
int son[MAXN];       // heavy child (0 when the vertex has no children)
int id[MAXN];        // DFS timestamp assigned by dfs2
int fa[MAXN];        // parent vertex
int tot;             // NOTE(review): not referenced anywhere in the visible code
int dep[MAXN];       // depth; the root passed to dfs1 gets depth 1
int siz[MAXN];       // subtree size
int top[MAXN];       // top vertex of the heavy chain containing the vertex
int wt[MAXN];        // wt[id[x]] == w[x]
int w[MAXN];         // initial vertex values

int res = 0;         // accumulator written by query()
int tim;             // running timestamp counter used by dfs2
// Append the directed edge x -> y at the front of x's adjacency list.
inline void addedge(int x, int y){
    ++cnt;
    e[cnt].nxt = head[x];
    e[cnt].to = y;
    head[x] = cnt;
}

// Push the pending add-tag of node rt (covering lenn positions) down to both
// children. The left child covers ceil(lenn/2) positions and the right child
// floor(lenn/2), matching the split at mid = (l + r) >> 1.
// The arithmetic is done in 64 bits and reduced modulo mod, because
// tag * length overflows int and unreduced child tags keep growing.
inline void pushdown(int rt, int lenn)
{
    const std::int64_t tag = lazy[rt];
    const int lc = rt << 1;
    const int rc = rt << 1 | 1;
    const std::int64_t left_len = lenn - (lenn >> 1);
    const std::int64_t right_len = lenn >> 1;
    lazy[lc] = static_cast<int>((lazy[lc] + tag) % mod);
    lazy[rc] = static_cast<int>((lazy[rc] + tag) % mod);
    a[lc] = static_cast<int>((a[lc] + tag % mod * left_len) % mod);
    a[rc] = static_cast<int>((a[rc] + tag % mod * right_len) % mod);
    lazy[rt] = 0;
}
// Build node rt covering timestamps [l, r]: each leaf takes wt[l] modulo mod,
// and each inner node stores the sum of its children modulo mod.
void build(int rt, int l, int r){
    if (l == r){
        a[rt] = wt[l] % mod;
        return;
    }
    build(lson);
    build(rson);
    // 64-bit add: both children may be close to mod, so their int sum can overflow.
    a[rt] = static_cast<int>((static_cast<std::int64_t>(a[rt << 1]) + a[rt << 1 | 1]) % mod);
}
// Add the sum (modulo mod) of the part of [L, R] that lies inside node rt's
// range [l, r] into the global accumulator res. The callers in this file reset
// res to 0 before the top-level call. Pending tags are pushed down first so the
// children hold up-to-date sums.
void query(int rt, int l, int r, int L, int R)
{
    if (L <= l && r <= R){
        // 64-bit add: res and a[rt] can each be close to mod, so an int sum may overflow.
        res = static_cast<int>((static_cast<std::int64_t>(res) + a[rt]) % mod);
        return;
    }
    if (lazy[rt]) pushdown(rt, len);
    if (L <= mid) query(lson, L, R);
    if (R > mid) query(rson, L, R);
}
// Add k to every position of [L, R] that lies inside node rt's range [l, r].
// A fully covered node only records the tag and adjusts its own sum; partial
// overlaps push tags down, recurse, and recompute the sum from the children.
// Everything is kept reduced modulo mod, and products use 64 bits so that
// k * length cannot overflow int.
inline void update(int rt, int l, int r, int L, int R, int k)
{
    if (L <= l && r <= R){
        const std::int64_t add = static_cast<std::int64_t>(k) % mod;
        lazy[rt] = static_cast<int>((lazy[rt] + add) % mod);
        a[rt] = static_cast<int>((a[rt] + add * len) % mod);
        return;
    }
    if (lazy[rt]) pushdown(rt, len);
    if (L <= mid) update(lson, L, R, k);
    if (R > mid) update(rson, L, R, k);
    a[rt] = static_cast<int>((static_cast<std::int64_t>(a[rt << 1]) + a[rt << 1 | 1]) % mod);
}

// Segment tree (xianduanshu) code above: checked.
// Sum (modulo mod) of the values on the tree path between x and y.
// While x and y are on different heavy chains, lift the endpoint whose chain
// top is deeper, summing that chain's contiguous timestamp range
// id[top]..id[x]. Once both are on one chain, add id[shallower]..id[deeper].
inline int qrange(int x, int y)
{
    std::int64_t ans = 0;   // 64-bit: ans + res can exceed INT_MAX when mod is large
    while (top[x] != top[y]){
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        res = 0;
        query(1, 1, n, id[top[x]], id[x]);
        ans = (ans + res) % mod;
        x = fa[top[x]];
    }
    if (dep[x] > dep[y]) std::swap(x, y);
    res = 0;
    query(1, 1, n, id[x], id[y]);
    ans = (ans + res) % mod;
    return static_cast<int>(ans);
}
// Add k (first reduced modulo mod) to every vertex on the tree path x..y,
// one heavy-chain segment at a time.
inline void update_range(int x, int y, int k){
    k %= mod;
    while (top[x] != top[y]){
        // Lift whichever endpoint sits on the chain with the deeper top.
        if (dep[top[y]] > dep[top[x]]) std::swap(x, y);
        const int chain_top = top[x];
        update(1, 1, n, id[chain_top], id[x], k);
        x = fa[chain_top];
    }
    // Same chain now: the shallower vertex has the smaller timestamp.
    const int upper = (dep[x] <= dep[y]) ? x : y;
    const int lower = (dep[x] <= dep[y]) ? y : x;
    update(1, 1, n, id[upper], id[lower], k);
}
// Sum (modulo mod) of all values in the subtree of x. That subtree occupies the
// contiguous timestamps id[x] .. id[x] + siz[x] - 1.
inline int qson(int x){
    const int first = id[x];
    const int last = first + siz[x] - 1;
    res = 0;
    query(1, 1, n, first, last);
    return res;
}
inline void update_son(int x, int k){
    update(1, 1, n, id[x], id[x] + siz[x] - 1, k);
}

void dfs1(int x, int f, int deep){
    dep[x] = deep;fa[x] = f; siz[x] = 1;
    int maxson = -1;
    for (ri i = head[x]; i ; i = e[i].nxt){
        int v = e[i].to;
        if (v == f) continue;
        dfs1(v, x, deep + 1);
        siz[x] += siz[v];
        if (siz[v] > maxson) son[x] = v, maxson = siz[v];
    }
}
void dfs2(int x, int topf)
{
    id[x] = ++tim; wt[tim] = w[x];
    top[x] = topf;
    if (!son[x]) return;
    dfs2(son[x], topf);
    for (ri i = head[x]; i; i = e[i].nxt){
        int y = e[i].to;
        if (y == fa[x] || y == son[x])
            continue;
        dfs2(y, y);
    }
}

// Reads n, m, the root r and the modulus mod, then the n vertex values and the
// n-1 undirected edges. It then decomposes the tree, builds the segment tree,
// and processes m operations:
//   1 x y z : add z to every vertex on the path x..y
//   2 x y   : print the path sum x..y (mod mod)
//   3 x z   : add z to every vertex in x's subtree
//   4 x     : print the subtree sum of x (mod mod)
// Any other opcode is read and ignored.
int main()
{
    read(n), read(m), read(r), read(mod);
    for (int v = 1; v <= n; ++v) read(w[v]);
    for (int i = 1; i < n; ++i){
        int u, v;
        read(u), read(v);
        addedge(u, v);
        addedge(v, u);
    }

    tim = 0;
    dfs1(r, 0, 1);     // root r: depth 1, pseudo-parent 0
    dfs2(r, r);        // the root starts its own heavy chain
    build(1, 1, n);

    while (m--)
    {
        int op, x, y, z;
        read(op);
        switch (op){
        case 1:
            read(x), read(y), read(z);
            update_range(x, y, z);
            break;
        case 2:
            read(x), read(y);
            printf("%d\n", qrange(x, y));
            break;
        case 3:
            read(x), read(y);
            update_son(x, y);
            break;
        case 4:
            read(x);
            printf("%d\n", qson(x));
            break;
        default:
            break;
        }
    }
    return 0;
}

算法分析:

首先这得是一棵树，需要支持以下四种操作：

操作 1（`1 x y z`）：将树从 x 到 y 结点最短路径上所有节点的值都加上 z

操作 2（`2 x y`）：求树从 x 到 y 结点最短路径上所有节点的值之和（对 mod 取模）

操作 3（`3 x z`）：将以 x 为根节点的子树内所有节点值都加上 z

操作 4（`4 x`）：求以 x 为根节点的子树内所有节点值之和（对 mod 取模）

1. 找出每个结点的重儿子以及它所在重链的起点（链顶），并以重儿子优先做 DFS 打时间戳

 // 找重儿子
 void dfs1(int x, int f, int deep){
    dep[x] = deep;fa[x] = f; siz[x] = 1;
    int maxson = -1;
    for (ri i = head[x]; i ; i = e[i].nxt){
        int v = e[i].to;
        if (v == f) continue;
        dfs1(v, x, deep + 1);
        siz[x] += siz[v];
        if (siz[v] > maxson) son[x] = v, maxson = siz[v];
    }
}
// 利用时间戳对其标记,重儿子优先
void dfs2(int x, int topf)
{
    id[x] = ++tim; wt[tim] = w[x];
    top[x] = topf;
    if (!son[x]) return;
    dfs2(son[x], topf);
    for (ri i = head[x]; i; i = e[i].nxt){
        int y = e[i].to;
        if (y == fa[x] || y == son[x])
            continue;
        dfs2(y, y);
    }
}

2. 按 DFS 序（同一条重链上的结点编号连续）构建线段树（QAQ，欺负我手残老是打错线段树）

 // Push the pending add-tag of node rt (covering lenn positions) down to both
 // children. The left child covers ceil(lenn/2) positions and the right child
 // floor(lenn/2), matching the split at mid = (l + r) >> 1.
 // The arithmetic is done in 64 bits and reduced modulo mod, because
 // tag * length overflows int and unreduced child tags keep growing.
inline void pushdown(int rt, int lenn)
{
    const std::int64_t tag = lazy[rt];
    const int lc = rt << 1;
    const int rc = rt << 1 | 1;
    const std::int64_t left_len = lenn - (lenn >> 1);
    const std::int64_t right_len = lenn >> 1;
    lazy[lc] = static_cast<int>((lazy[lc] + tag) % mod);
    lazy[rc] = static_cast<int>((lazy[rc] + tag) % mod);
    a[lc] = static_cast<int>((a[lc] + tag % mod * left_len) % mod);
    a[rc] = static_cast<int>((a[rc] + tag % mod * right_len) % mod);
    lazy[rt] = 0;
}
// Build node rt covering timestamps [l, r]: each leaf takes wt[l] modulo mod,
// and each inner node stores the sum of its children modulo mod.
void build(int rt, int l, int r){
    if (l == r){
        a[rt] = wt[l] % mod;
        return;
    }
    build(lson);
    build(rson);
    // 64-bit add: both children may be close to mod, so their int sum can overflow.
    a[rt] = static_cast<int>((static_cast<std::int64_t>(a[rt << 1]) + a[rt << 1 | 1]) % mod);
}
// Add the sum (modulo mod) of the part of [L, R] that lies inside node rt's
// range [l, r] into the global accumulator res. Callers reset res to 0 before
// the top-level call. Pending tags are pushed down first so the children hold
// up-to-date sums.
void query(int rt, int l, int r, int L, int R)
{
    if (L <= l && r <= R){
        // 64-bit add: res and a[rt] can each be close to mod, so an int sum may overflow.
        res = static_cast<int>((static_cast<std::int64_t>(res) + a[rt]) % mod);
        return;
    }
    if (lazy[rt]) pushdown(rt, len);
    if (L <= mid) query(lson, L, R);
    if (R > mid) query(rson, L, R);
}
// Add k to every position of [L, R] that lies inside node rt's range [l, r].
// A fully covered node only records the tag and adjusts its own sum; partial
// overlaps push tags down, recurse, and recompute the sum from the children.
// Everything is kept reduced modulo mod, and products use 64 bits so that
// k * length cannot overflow int.
inline void update(int rt, int l, int r, int L, int R, int k)
{
    if (L <= l && r <= R){
        const std::int64_t add = static_cast<std::int64_t>(k) % mod;
        lazy[rt] = static_cast<int>((lazy[rt] + add) % mod);
        a[rt] = static_cast<int>((a[rt] + add * len) % mod);
        return;
    }
    if (lazy[rt]) pushdown(rt, len);
    if (L <= mid) update(lson, L, R, k);
    if (R > mid) update(rson, L, R, k);
    a[rt] = static_cast<int>((static_cast<std::int64_t>(a[rt << 1]) + a[rt << 1 | 1]) % mod);
}

3. 最后就是树上操作了

 // Sum (modulo mod) of the values on the tree path between x and y.
 // If x and y are not on the same heavy chain, keep jumping until they are:
 // lift the endpoint whose chain top is deeper, summing that chain's
 // contiguous timestamp range id[top]..id[x]. Then add id[shallower]..id[deeper].
inline int qrange(int x, int y)
{
    std::int64_t ans = 0;   // 64-bit: ans + res can exceed INT_MAX when mod is large
    while (top[x] != top[y]){
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        res = 0;
        query(1, 1, n, id[top[x]], id[x]);
        ans = (ans + res) % mod;
        x = fa[top[x]];
    }
    if (dep[x] > dep[y]) std::swap(x, y);
    res = 0;
    query(1, 1, n, id[x], id[y]);
    ans = (ans + res) % mod;
    return static_cast<int>(ans);
}
// Add k (first reduced modulo mod) to every vertex on the tree path x..y,
// using the same chain-jumping scheme as qrange.
inline void update_range(int x, int y, int k){
    k %= mod;
    while (top[x] != top[y]){
        // Lift whichever endpoint sits on the chain with the deeper top.
        if (dep[top[y]] > dep[top[x]]) std::swap(x, y);
        const int chain_top = top[x];
        update(1, 1, n, id[chain_top], id[x], k);
        x = fa[chain_top];
    }
    // Same chain now: the shallower vertex has the smaller timestamp.
    const int upper = (dep[x] <= dep[y]) ? x : y;
    const int lower = (dep[x] <= dep[y]) ? y : x;
    update(1, 1, n, id[upper], id[lower], k);
}
// Sum (modulo mod) of all values in the subtree of x. That subtree occupies the
// contiguous timestamps id[x] .. id[x] + siz[x] - 1.
inline int qson(int x){
    const int first = id[x];
    const int last = first + siz[x] - 1;
    res = 0;
    query(1, 1, n, first, last);
    return res;
}
inline void update_son(int x, int k){
    update(1, 1, n, id[x], id[x] + siz[x] - 1, k);// 对他的儿子进行更新
}