dfs 序基础 + 树上差分

个人记录

dfs 序

dfs 序就是在 dfs 过程中到达和离开每个节点 $u$ 时的时间戳,分别记为 $s_u$ 和 $e_u$(离开时计数器不再自增,所以 $e_u$ 就是 $u$ 子树中最大的 $s$ 值)。

对于一棵以 $u$ 为根的子树,其中任意节点 $v$ 都满足 $s_u \le s_v$ 且 $e_v \le e_u$,即子孙节点的时间戳一定落在根节点的 $s_u$ 与 $e_u$ 之间。

所以我们可以把 $u$ 的子树看成一段连续区间 $[s_u, e_u]$:子树中每个节点 $v$ 都有 $s_v, e_v \in [s_u, e_u]$,反过来,时间戳落在这段区间内的节点也恰好都在 $u$ 的子树里。

接着我们就可以对此区间进行操作,例如修改、查询……

T664766 dfs序基础1

同上,按 dfs 序建一个树状数组维护即可:单点修改就是在位置 $s_u$ 上加值,子树求和就是查询区间 $[s_u, e_u]$ 的和。

#include <bits/stdc++.h>
#define int long long

using namespace std;

const int N = 1e6 + 10;

// a: node weights; s/e: dfs entry/exit timestamps; tim: timestamp counter.
// tree is a Fenwick tree indexed by entry timestamp (1..n), so N slots are
// enough -- the 4x sizing is only needed for a segment tree.
int n, m, r, a[N], s[N], e[N], tree[N], tim;
// Adjacency lists of the (undirected) tree.
vector<int> nbr[N];

// Value of the lowest set bit of x (e.g. 12 -> 4), via the two's-complement trick.
int lowbit(int x) {
    const int negated = -x;
    return negated & x;
}

// Fenwick point update: add y at position x, then at every position whose
// range covers x, stopping once the index passes n.
void modify(int x, int y) {
    for (int pos = x; pos <= n; pos += lowbit(pos)) {
        tree[pos] += y;
    }
}

// Fenwick prefix query: returns the sum of positions 1..x.
int query(int x) {
    int total = 0;
    for (int pos = x; pos != 0; pos -= lowbit(pos)) {
        total += tree[pos];
    }
    return total;
}

// Recursive dfs from u (coming from parent fa): assigns the entry timestamp
// s[u], inserts a[u] into the Fenwick tree at that timestamp, then sets e[u]
// to the largest timestamp inside u's subtree, so the subtree is [s[u], e[u]].
void dfs(int u, int fa) {
    tim++;
    s[u] = tim;
    modify(tim, a[u]);
    for (int v : nbr[u]) {
        if (v == fa) continue;
        dfs(v, u);
    }
    e[u] = tim;
}

signed main() {
    // Fast I/O. Avoid `cout.tie(0)->sync_with_stdio(false)`: cout.tie(0)
    // returns cout's previous tie, which is nullptr by default, so that call
    // goes through a null pointer.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m >> r;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
    }
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        nbr[u].push_back(v);
        nbr[v].push_back(u);
    }
    dfs(r, 0);
    while (m--) {
        int op;
        cin >> op;
        if (op == 1) {
            // Point update: add x to the weight of `node`.
            int node, x;
            cin >> node >> x;
            modify(s[node], x);
        } else {
            // Subtree sum: node's subtree is exactly timestamps [s[node], e[node]].
            int node;
            cin >> node;
            cout << query(e[node]) - query(s[node] - 1) << '\n';
        }
    }
    return 0;
}

T664767 dfs序基础2

套一个带懒标记的线段树:记录 $num_{tim}$ 为时间戳为 $tim$ 时经过的节点,建树时第 $i$ 个叶子取 $a_{num_i}$;这样子树加、子树求和就变成了对区间 $[s_u, e_u]$ 的区间加、区间求和。

#include <bits/stdc++.h>
#define int long long
// Children of segment-tree node x. The argument is parenthesized so that
// expressions such as lson(k + 1) expand correctly.
#define lson(x) ((x) << 1)
#define rson(x) (((x) << 1) | 1)

using namespace std;

const int N = 1e6 + 10;

// a: node weights; s/e: dfs entry/exit timestamps; tim: timestamp counter;
// num[t]: node whose entry timestamp is t; tree/lazy: segment-tree sums and
// pending range-add tags (4N nodes).
int n, m, r, a[N], s[N], e[N], tree[N << 2], tim, num[N], lazy[N << 2];
// Adjacency lists of the (undirected) tree.
vector<int> nbr[N];

void push_up(int k) {
    tree[k] = tree[lson(k)] + tree[rson(k)];
    return;
}

// Apply "+v to every element" to node k covering [l, r]: increase its sum by
// v per element and accumulate v into its lazy tag for later push_down.
void update_son(int k, int l, int r, int v) {
    const int len = r - l + 1;
    tree[k] += len * v;
    lazy[k] += v;
}

// Forward node k's pending lazy tag (if non-zero) to both children, then clear it.
void push_down(int k, int l, int r) {
    if (lazy[k] != 0) {
        const int tag = lazy[k];
        const int mid = (l + r) >> 1;
        update_son(lson(k), l, mid, tag);
        update_son(rson(k), mid + 1, r, tag);
        lazy[k] = 0;
    }
}

// Build the segment tree over timestamps [l, r] rooted at node k: leaf t holds
// a[num[t]] (the weight of the node entered at time t); every lazy tag is reset.
void build(int k, int l, int r) {
    lazy[k] = 0;
    if (l != r) {
        const int mid = (l + r) >> 1;
        build(lson(k), l, mid);
        build(rson(k), mid + 1, r);
        push_up(k);
    } else {
        tree[k] = a[num[l]];
    }
}

// Sum of the query range [L, R] intersected with node k's range [l, r].
int query(int L, int R, int l, int r, int k) {
    if (L <= l && r <= R) return tree[k];  // node fully covered
    push_down(k, l, r);
    const int mid = (l + r) >> 1;
    int left_sum = 0, right_sum = 0;
    if (L <= mid) left_sum = query(L, R, l, mid, lson(k));
    if (R > mid) right_sum = query(L, R, mid + 1, r, rson(k));
    return left_sum + right_sum;
}

// Add v to every position of [L, R] inside node k's range [l, r].
void update(int L, int R, int v, int l, int r, int k) {
    if (L <= l && r <= R) {
        // Fully covered: tag this node instead of descending.
        update_son(k, l, r, v);
        return;
    }
    push_down(k, l, r);
    const int mid = (l + r) >> 1;
    if (L <= mid) update(L, R, v, l, mid, lson(k));
    if (mid < R) update(L, R, v, mid + 1, r, rson(k));
    push_up(k);
}

// Assign entry timestamp s[u] and record the inverse map num[s[u]] = u (used
// by build), recurse into children, then set e[u] to the last timestamp used
// inside u's subtree.
void dfs(int u, int fa) {
    num[++tim] = u;
    s[u] = tim;
    for (int v : nbr[u]) {
        if (v != fa) dfs(v, u);
    }
    e[u] = tim;
}

signed main() {
    // Fast I/O. Avoid `cout.tie(0)->sync_with_stdio(false)`: cout.tie(0)
    // returns cout's previous tie, which is nullptr by default, so that call
    // goes through a null pointer.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m >> r;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
    }
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        nbr[u].push_back(v);
        nbr[v].push_back(u);
    }
    dfs(r, 0);
    build(1, 1, n);
    while (m--) {
        int op;
        cin >> op;
        if (op == 1) {
            // Add x to every node in node's subtree, i.e. timestamps [s[node], e[node]].
            int node, x;
            cin >> node >> x;
            update(s[node], e[node], x, 1, n, 1);
        } else {
            // Sum of node's subtree.
            int node;
            cin >> node;
            cout << query(s[node], e[node], 1, n, 1) << '\n';
        }
    }
    return 0;
}

P2982 [USACO10FEB] Slowing down G

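很容易发现可以对树进行差分:每处理完一个点 $x$,就给 $x$ 的整棵子树(区间 $[s_x, e_x]$)加 $1$,用树状数组维护差分数组(在 $s_x$ 处 $+1$、在 $e_x + 1$ 处 $-1$);处理 $x$ 之前先求 $s_x$ 处的前缀和,得到的就是之前处理过的、位于根到 $x$ 路径上的点数。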

#include <bits/stdc++.h>
#define int long long

using namespace std;

const int N = 1e6 + 10;

// s/e: dfs entry/exit timestamps; tim: timestamp counter.
// tree: Fenwick tree over timestamps 1..n holding a difference array; a
// Fenwick tree needs only N slots (4x sizing is for segment trees).
// a is never read from input in this problem and stays all-zero.
int n, a[N], s[N], e[N], tree[N], tim;
// Adjacency lists of the (undirected) tree.
vector<int> nbr[N];

// Value of the least-significant 1 bit of x; ~x + 1 is the two's-complement negation of x.
int lowbit(int x) { return x & (~x + 1); }

// Fenwick point add: position x and every covering position up to n receive y.
void modify(int x, int y) {
    for (; x <= n; x += lowbit(x)) tree[x] += y;
}

// Fenwick prefix sum of positions 1..x. Because the tree stores a difference
// array here, this is the current value at position x.
int query(int x) {
    int acc = 0;
    for (; x != 0; x -= lowbit(x)) acc += tree[x];
    return acc;
}

// Assign dfs entry/exit timestamps so that u's subtree is exactly [s[u], e[u]].
// No initial weights are inserted into the Fenwick tree: this problem has no
// node weights (a[] is never filled), and the tree only holds the difference
// array that main updates.
void dfs(int u, int fa) {
    s[u] = ++tim;
    for (int v : nbr[u]) {
        if (v != fa) {
            dfs(v, u);
        }
    }
    e[u] = tim;
}

signed main() {
    // Fast I/O. Avoid `cout.tie(0)->sync_with_stdio(false)`: cout.tie(0)
    // returns cout's previous tie, which is nullptr by default, so that call
    // goes through a null pointer.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n;
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        nbr[u].push_back(v);
        nbr[v].push_back(u);
    }
    dfs(1, 0);
    for (int i = 1; i <= n; i++) {
        int x;
        cin >> x;
        // Point value at s[x] = number of earlier x' whose subtree contains x
        // (i.e. earlier ancestors of x, or x itself).
        cout << query(s[x]) << '\n';
        // Range +1 on x's subtree [s[x], e[x]] through the difference array.
        modify(s[x], 1);
        modify(e[x] + 1, -1);
    }
    return 0;
}

P3128 [USACO15DEC] Max Flow P

依旧树上差分。对于路径 $(u, v)$,令 $diff_u + 1$、$diff_v + 1$、$diff_{lca} - 1$、$diff_{fa(lca)} - 1$,再 dfs 求一遍子树和,每个点的子树和就是经过它的路径条数,取最大值即可。

#include <bits/stdc++.h>
#define int long long

using namespace std;

const int N = 3e5 + 10;

int n, m;
// dp[u][i]: the 2^i-th ancestor of u (0 once past the root).
int dp[N][25];
// siz[u]: depth of u with the root at depth 1 (not a subtree size, despite the name).
int siz[N];
// diff: path difference marks; sum: subtree totals of diff.
int diff[N], sum[N];
// Adjacency lists of the (undirected) tree.
vector<int> nbr[N];

// Compute the depth siz[u] (= siz[fa] + 1) and the binary-lifting row dp[u][*]
// for u, then recurse into u's children.
void dfs1(int u, int fa) {
    dp[u][0] = fa;
    siz[u] = siz[fa] + 1;
    for (int i = 1; (1 << i) <= siz[u]; ++i) {
        const int half = dp[u][i - 1];  // 2^(i-1)-th ancestor
        dp[u][i] = dp[half][i - 1];
    }
    for (int v : nbr[u]) {
        if (v == fa) continue;
        dfs1(v, u);
    }
}

int lca(int u, int v) {
    if (siz[u] > siz[v])
        swap(u, v);
    for (int i = 20; i >= 0; i--) {
        if (siz[dp[v][i]] >= siz[u])
            v = dp[v][i];
    }
    if (u == v)
        return u;
    for (int i = 20; i >= 0; i--) {
        if (dp[v][i] != dp[u][i])
            u = dp[u][i], v = dp[v][i];
    }
    return dp[u][0];
}

// Post-order accumulation: sum[u] becomes the total of diff over u's subtree.
void dfs2(int u, int fa) {
    int total = diff[u];
    for (int v : nbr[u]) {
        if (v == fa) continue;
        dfs2(v, u);
        total += sum[v];
    }
    sum[u] = total;
}

signed main() {
    // Fast I/O. Avoid `cout.tie(0)->sync_with_stdio(false)`: cout.tie(0)
    // returns cout's previous tie, which is nullptr by default, so that call
    // goes through a null pointer.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        nbr[u].push_back(v);
        nbr[v].push_back(u);
    }
    // dfs1 sets siz[1] = siz[0] + 1 = 1 itself, so no separate init is needed.
    dfs1(1, 0);
    while (m--) {
        int x, y;
        cin >> x >> y;
        // Path (x, y): +1 at both endpoints, -1 at the LCA and at its parent,
        // so after subtree summation each node on the path gains exactly 1.
        int LCA = lca(x, y);
        int fa = dp[LCA][0];  // 0 when LCA is the root; diff[0] is never summed
        diff[x]++, diff[y]++, diff[LCA]--, diff[fa]--;
    }
    dfs2(1, 0);
    // sum[i] counts paths through node i, so it is never negative.
    int maxn = 0;
    for (int i = 1; i <= n; i++) {
        maxn = max(maxn, sum[i]);
    }
    cout << maxn << '\n';
    return 0;
}