DDP(动态dp)学习笔记

· · 算法·理论

动态DP

create:2024.11.16

P4719 【模板】动态 DP

题目链接

1.正常思路

一个很显然的转移是:

f_{i,0}=\sum _{son} max(f_{son,0},f_{son,1}) f_{i,1}=a_i+\sum _{son} f_{son,0}

其中 f_{i,0/1} 表示第 i 个节点选 / 不选时子树内的答案

但是,这只能解决静态的树上最大权独立集问题

如果我们要带修呢? 动态dp就出场了

我们可以先对树进行重链剖分 得到每个点的重儿子和轻儿子

我们设 g_{i,0/1} 表示第 i 个节点选 / 不选时轻儿子的答案

转移变为:(其中 son 表示节点 i 的重儿子)

f_{i,0}=g_{i,0}+max(f_{son,0},f_{son,1}) f_{i,1}=g_{i,1}+f_{son,0}

2.广义矩阵乘法

上述转移让我们很想用矩阵乘法的形式进行转移,但是转移中涉及了 max 操作,所以我们先看看普通的矩阵乘法

C=A\times B \ C_{i,j}=\sum _{k=1}^n A_{i,k}\times B_{k,j}

若我们将乘法改为加法,加法改为 max ,即为广义矩阵乘法:

C_{i,j}=\max _{k=1}^n \{ A_{i,k}+B_{k,j}\}

我们惊奇地发现,这个广义矩阵乘法(以下简称广义矩乘)也满足结合律 (请读者自证)

其单位矩阵为(以2×2为例):\begin{bmatrix}0&-\infty\\-\infty&0\end{bmatrix}

前面的转移为 :

\begin{bmatrix}g_{i,0}&g_{i,0}\\g_{i,1}&-\infty\end{bmatrix}·\begin{bmatrix}f_{son,0}\\f_{son,1}\end{bmatrix}=\begin{bmatrix}f_{i,0}\\f_{i,1}\end{bmatrix}

3.结合线段树

前面我们对树进行了树链剖分的操作,那我们就要充分利用轻重儿子的性质

既然广义矩乘满足结合律,那么我们可以使用线段树维护转移矩阵,从而加速修改和查询过程

4.完整步骤

预处理:

  1. 树链剖分并 dfs 预处理出每个点的 f_{i,0/1}g_{i,0/1} 的值 (树链剖分需要多维护一个值 end 表示链底)

  2. 按照dfs序建线段树维护转移矩阵 (其中叶子节点的值为单位矩阵)

查询点 x 所在链的答案:

  1. 用矩阵 T 左乘矩阵 \begin{bmatrix}0\\a_{end[x]}\end{bmatrix} 得到答案矩阵 \begin{bmatrix}f_{top[x],0}\\f_{top[x],1}\end{bmatrix}

  2. ans=max(f_{top[x],0},f_{top[x],1})

修改点 x 的权值为 y

  1. old 为节点 x 权值修改前所在链的答案矩阵

  2. 更改节点 x 的转移矩阵为 \begin{bmatrix}g_{x,0}&g_{x,0}\\g_{x,1}-a_x+y&-\infty\end{bmatrix}

  3. new 为节点 x 权值修改后所在链的答案矩阵

  4. fa[top[x]] 不为 0 ,则让 x=fa[top[x]]

  5. 修改节点 x 的转移矩阵为 \begin{bmatrix}g_{x,0}-max(old)+max(new)&g_{x,0}-max(old)+max(new)\\g_{x,1}-old[0]+new[0]&-\infty\end{bmatrix}

  6. 重复上述步骤(步骤 2 除外)

5.代码

#include <bits/stdc++.h>
#define getmid int mid=(l+r)>>1
#define ls (i<<1)
#define rs (i<<1|1)
using namespace std;
const int N = 1e5+5, MIN = -1e9;
struct EDGE {
    int v, nxt;
} e[N << 1];
struct Matrix {
    int ma[2][2];
    void f(int a1, int a2, int b1, int b2) {
        ma[0][0] = a1, ma[0][1] = a2, ma[1][0] = b1, ma[1][1] = b2;
    }
    Matrix operator*(const Matrix& a) {
        Matrix c;
        c.ma[0][0] = max(ma[0][0] + a.ma[0][0], ma[0][1] + a.ma[1][0]);
        c.ma[0][1] = max(ma[0][0] + a.ma[0][1], ma[0][1] + a.ma[1][1]);
        c.ma[1][0] = max(ma[1][0] + a.ma[0][0], ma[1][1] + a.ma[1][0]);
        c.ma[1][1] = max(ma[1][0] + a.ma[0][1], ma[1][1] + a.ma[1][1]);
        return c;
    }
} base, trans[N], old, ne, tre[N << 2];
int n, m, h[N], cnt, a[N], f[N][2], g[N][2];
int fa[N], son[N], dep[N], siz[N], top[N], down[N], dfn[N], id[N], dfncnt;
void add_edge(int u, int v) {
    e[cnt] = {v, h[u]};
    h[u] = cnt++;
}
void dfs1(int u) {
    siz[u] = 1, son[u] = -1;
    for (int i = h[u]; ~i; i = e[i].nxt) {
        int v = e[i].v;
        if (v != fa[u]) {
            fa[v] = u, dep[v] = dep[u] + 1;
            dfs1(v);
            siz[u] += siz[v];
            if (son[u] == -1 || siz[v] > siz[son[u]]) son[u] = v;
        }
    }
}
void dfs2(int u, int t) {
    top[u] = t, dfn[u] = ++dfncnt, id[dfncnt] = u;
    if (son[u] == -1) {
        down[u] = u;
        return;
    }
    dfs2(son[u], t), down[u] = down[son[u]];
    for (int i = h[u]; ~i; i = e[i].nxt) {
        int v = e[i].v;
        if (v != fa[u] && v != son[u]) dfs2(v, v);
    }
}
void dfs3(int u) {
    f[u][1] = a[u];
    for (int i = h[u]; ~i; i = e[i].nxt) {
        int v = e[i].v;
        if (v != fa[u]) {
            dfs3(v);
            f[u][0] += max(f[v][0], f[v][1]);
            f[u][1] += f[v][0];
        }
    }
    if (son[u] == -1) return;
    g[u][0] = f[u][0] - max(f[son[u]][0], f[son[u]][1]);
    g[u][1] = f[u][1] - f[son[u]][0] - a[u];
}
void build(int l = 1, int r = n, int i = 1) {
    if (l == r) {
        tre[i] = trans[id[l]];
        return;
    }
    getmid;
    build(l, mid, ls), build(mid + 1, r, rs);
    tre[i] = tre[ls] * tre[rs];
}
void modify(int a, Matrix p, int l = 1, int r = n, int i = 1) {
    if (a == l && a == r) {
        tre[i] = p;
        return;
    }
    getmid;
    if (a <= mid) modify(a, p, l, mid, ls);
    else modify(a, p, mid + 1, r, rs);
    tre[i] = tre[ls] * tre[rs];
}
Matrix query(int a, int b, int l = 1, int r = n, int i = 1) {
    if (a <= l && r <= b) return tre[i];
    getmid;
    Matrix ans = base;
    if (mid + 1 <= b) ans = query(a, b, mid + 1, r, rs) * ans;
    if (a <= mid) ans = query(a, b, l, mid, ls) * ans;
    return ans;
}
Matrix query_l(int x) {
    Matrix t;
    t.f(0, MIN, a[down[x]], MIN);
    t = query(dfn[top[x]], dfn[down[x]]) * t;
    return t;
}
signed main() {
    ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
    memset(h, -1, sizeof(h));
    base.f(0, MIN, MIN, 0);
    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> a[i];
    for (int i = 1, u, v; i < n; i++) cin >> u >> v, add_edge(u, v), add_edge(v, u);
    dfs1(1), dfs2(1, 1), dfs3(1);
    for (int i = 1; i <= n; i++) {
        if (son[i] == -1) trans[i] = base;
        else trans[i].f(g[i][0], g[i][0], g[i][1] + a[i], MIN);
    }
    build();
    for (int i = 1, x, y; i <= m; i++) {
        cin >> x >> y;
        (trans[x].ma[1][0] -= a[x]) += y;
        old = query_l(x);
        a[x] = y;
        modify(dfn[x], trans[x]);
        ne = query_l(x);
        while (fa[top[x]] != 0) {
            x = fa[top[x]];
            (trans[x].ma[0][0] -= max(old.ma[0][0], old.ma[1][0])) += max(ne.ma[0][0], ne.ma[1][0]);
            trans[x].ma[0][1] = trans[x].ma[0][0];
            (trans[x].ma[1][0] -= old.ma[0][0]) += ne.ma[0][0];
            old = query_l(x);
            modify(dfn[x], trans[x]);
            ne = query_l(x);
        }
        Matrix ANS = query_l(1);
        cout << max(ANS.ma[0][0], ANS.ma[1][0]) << "\n";
    }
    return 0;
}

时间复杂度:O(2^3mlog^2n)