题解:P4719 【模板】动态 DP

· · 题解

动态 dp 主要运用于树上或者区间上的 dp 问题,核心是将本来的转移过程写成矩阵转移的形式,构造矩阵 g_i 以及广义矩阵乘法使得 f_j \times g_i = f_i

然后呢,我们对于这个 g_i,通过发掘原题的性质,让它在原做法中乘上若干个 g 变成只乘一个 g。换句话说呢就是需要维护一段连续的区间,在这个区间里提前进行转移(可以用线段树维护)。一般树上的做法可以是树链剖分变成 dfs 序再用线段树维护转移矩阵。做法因题目而异。

下面看例题。

不妨先考虑没有修改怎么做,套路的考虑树形 dp,设 f_{i,0/1} 表示 i 节点选不选,那么就有 f_{i,0}=\sum_{j\in son_i} \max(f_{j,0},f_{j,1}) 表示儿子选不选随意, f_{i,1}=a_i+\sum_{j\in son_i} f_{j,0} 表示儿子一定不能选。

这个时候我们再引入修改操作,由于每个点的 f 值只与他的子节点的值有关,所以我们发现它只会修改一条上的 f 值。如果考虑维护链上信息,那么就是需要引入树链剖分了。

但是发现信息不是特别好维护,可以定义一个辅助修改数组 g_{i,0}=\sum_{j\neq h_i} \max(f_{j,0},f_{j,1})g_{i,1}=a_i + \sum_{j\neq h_i} f_{j,0},其中 h_i 表示 i 的重儿子。

那么转移就变成了 f_{i,0}=\max(f_{h_i,0},f_{h_i,1})+g_{i,0}f_{i,1}=f_{h_i,0}+g_{i,1}

发现很符合矩阵的样子是吧,那么可以定义广义矩阵乘法 c_{i,j}=\max_k a_{i,k}+b_{k,j},那么就变成了

f_{h_i,0} & f_{h_i,1} \end{bmatrix} \times \begin{bmatrix} g_{i,0} & g_{i,1} \\ g_{i,0} & -\infty \end{bmatrix} = \begin{bmatrix} f_{i,0} & f_{i,1} \end{bmatrix}

那么用线段树维护这个转移矩阵,再用树剖维护链上的,统计答案就是简单的了。修改时候只需要对每个链上可能会修改的转移矩阵修改一下差值即可。

注意转移过程是由深入浅的,所以线段树得倒着乘。

const int N = 1e5 + 19, mod = 998244353, inf = -1e18;
int n, m, f[N][2], g[N][2], fa[N], dfn[N], cnt, top[N], siz[N], son[N], dep[N], a[N], deepest[N], rnk[N];
vector<int> G[N];
struct Matrix {
    int a[2][2];
    Matrix() { for (int i:{0,1}) for (int j:{0,1}) a[i][j]=0; }
    Matrix operator * (const Matrix e) {
        Matrix ans;
        for (int i:{0,1}) for (int j:{0,1}) for (int k:{0,1}) smax(ans.a[i][j],a[i][k]+e.a[k][j]);
        return ans;
    }
    void out() {
        puts("Matrix:");
        for (int i:{0,1}) for (int j:{0,1}) write(a[i][j]," \n"[j==1]);
    }
} p[N];
void dfs(int u, int f) {
    dep[u]=dep[f]+1; siz[u]=1; fa[u]=f;
    for (int v:G[u]) {
        if (v==f) continue;
        dfs(v,u); siz[u]+=siz[v]; if (siz[v]>siz[son[u]]) son[u]=v;
    }
}
void dfs2(int u,int tp) {
    dfn[u]=++cnt; top[u]=tp; deepest[u]=u; rnk[cnt]=u;
    if (son[u]) {
        dfs2(son[u],tp); deepest[u]=deepest[son[u]];
    }
    g[u][1]+=a[u];
    for (int v:G[u]) {
        if (v!=fa[u]&&v!=son[u]) {
            dfs2(v,v);
            g[u][0]+=max(f[v][0],f[v][1]);
            g[u][1]+=max(f[v][0],inf);
        }
    }
    f[u][0]=max(f[son[u]][0],f[son[u]][1])+g[u][0];
    f[u][1]=f[son[u]][0]+g[u][1];
//  printf("Id=%4lld    G: %lld %lld\n", u,g[u][0],g[u][1]);
    p[u].a[0][0]=p[u].a[1][0]=g[u][0]; p[u].a[0][1]=g[u][1]; p[u].a[1][1]=inf;
}
void add(int &a, int b) {
    if ((a += b) >= mod) a -= mod;
}
struct SegTree {
    Matrix t[N << 2];
    #define ls (k << 1)
    #define rs (k << 1 | 1)
    #define mid ((l + r) >> 1)
    void build(int k, int l, int r) {
        if (l == r) return t[k] = p[rnk[l]], void();
        build(ls, l, mid); build(rs, mid+1, r);
        t[k] = t[rs] * t[ls];
    }
    void update(int k, int l, int r, int s) {
        if (l == r) return t[k] = p[rnk[l]], void();
        if (s <= mid) update(ls, l, mid, s);
        else update(rs, mid+1, r, s);
        t[k] = t[rs] * t[ls];
    }
    Matrix query(int k, int l, int r, int L, int R) {
        if (L <= l && r <= R) return t[k];
        if (L > mid) return query(rs, mid+1, r, L, R);
        if (R <= mid) return query(ls, l, mid, L, R);
        return query(rs, mid+1, r, L, R) * query(ls, l, mid, L, R);
    }
} seg; 
void modify(int u, int w) {
    g[u][1]+=w-a[u]; a[u]=w;
    while (u) {
        Matrix A=seg.query(1,1,n,dfn[top[u]],dfn[deepest[top[u]]]);
        p[u].a[0][0]=p[u].a[1][0]=g[u][0]; p[u].a[0][1]=g[u][1]; p[u].a[1][1]=inf;
        seg.update(1,1,n,dfn[u]);
        Matrix B=seg.query(1,1,n,dfn[top[u]],dfn[deepest[top[u]]]);
        int v=top[u];
        g[fa[v]][0]-=max(A.a[0][0],A.a[0][1]);
        g[fa[v]][0]+=max(B.a[0][0],B.a[0][1]);
        g[fa[v]][1]-=A.a[0][0];
        g[fa[v]][1]+=B.a[0][0];
        u = fa[top[u]];
    }
}
signed main() {
    read(n, m);
    for (int i=1;i<=n;++i) read(a[i]);
    for (int i=1,u,v;i<n;++i) {
        read(u,v); G[u].eb(v); G[v].eb(u);
    } dfs(1,0); dfs2(1,1); seg.build(1,1,n);
//  (p[2]*p[5]).out();
    for (int i=1;i<=m;++i) {
        int x,y; read(x,y);
        modify(x,y);
        Matrix ret=seg.query(1,1,n,dfn[1],dfn[deepest[1]]);
        write(max(ret.a[0][1],ret.a[0][0]));
    }
    return 0;
}