题解:P4719 【模板】动态 DP
动态 dp 主要运用于树上或者区间上的 dp 问题,核心是将本来的转移过程写成矩阵转移的形式,构造矩阵
然后呢,我们对于这个
下面看例题。
不妨先考虑没有修改怎么做,套路的考虑树形 dp,设
这个时候我们再引入修改操作,由于每个点的
但是发现信息不是特别好维护,可以定义一个辅助修改数组
那么转移就变成了
发现很符合矩阵的样子是吧,那么可以定义广义矩阵乘法
那么用线段树维护这个转移矩阵,再用树剖维护链上的,统计答案就是简单的了。修改时候只需要对每个链上可能会修改的转移矩阵修改一下差值即可。
注意转移过程是由深入浅的,所以线段树得倒着乘。
const int N = 1e5 + 19, mod = 998244353, inf = -1e18;
int n, m, f[N][2], g[N][2], fa[N], dfn[N], cnt, top[N], siz[N], son[N], dep[N], a[N], deepest[N], rnk[N];
vector<int> G[N];
struct Matrix {
int a[2][2];
Matrix() { for (int i:{0,1}) for (int j:{0,1}) a[i][j]=0; }
Matrix operator * (const Matrix e) {
Matrix ans;
for (int i:{0,1}) for (int j:{0,1}) for (int k:{0,1}) smax(ans.a[i][j],a[i][k]+e.a[k][j]);
return ans;
}
void out() {
puts("Matrix:");
for (int i:{0,1}) for (int j:{0,1}) write(a[i][j]," \n"[j==1]);
}
} p[N];
void dfs(int u, int f) {
dep[u]=dep[f]+1; siz[u]=1; fa[u]=f;
for (int v:G[u]) {
if (v==f) continue;
dfs(v,u); siz[u]+=siz[v]; if (siz[v]>siz[son[u]]) son[u]=v;
}
}
void dfs2(int u,int tp) {
dfn[u]=++cnt; top[u]=tp; deepest[u]=u; rnk[cnt]=u;
if (son[u]) {
dfs2(son[u],tp); deepest[u]=deepest[son[u]];
}
g[u][1]+=a[u];
for (int v:G[u]) {
if (v!=fa[u]&&v!=son[u]) {
dfs2(v,v);
g[u][0]+=max(f[v][0],f[v][1]);
g[u][1]+=max(f[v][0],inf);
}
}
f[u][0]=max(f[son[u]][0],f[son[u]][1])+g[u][0];
f[u][1]=f[son[u]][0]+g[u][1];
// printf("Id=%4lld G: %lld %lld\n", u,g[u][0],g[u][1]);
p[u].a[0][0]=p[u].a[1][0]=g[u][0]; p[u].a[0][1]=g[u][1]; p[u].a[1][1]=inf;
}
void add(int &a, int b) {
if ((a += b) >= mod) a -= mod;
}
struct SegTree {
Matrix t[N << 2];
#define ls (k << 1)
#define rs (k << 1 | 1)
#define mid ((l + r) >> 1)
void build(int k, int l, int r) {
if (l == r) return t[k] = p[rnk[l]], void();
build(ls, l, mid); build(rs, mid+1, r);
t[k] = t[rs] * t[ls];
}
void update(int k, int l, int r, int s) {
if (l == r) return t[k] = p[rnk[l]], void();
if (s <= mid) update(ls, l, mid, s);
else update(rs, mid+1, r, s);
t[k] = t[rs] * t[ls];
}
Matrix query(int k, int l, int r, int L, int R) {
if (L <= l && r <= R) return t[k];
if (L > mid) return query(rs, mid+1, r, L, R);
if (R <= mid) return query(ls, l, mid, L, R);
return query(rs, mid+1, r, L, R) * query(ls, l, mid, L, R);
}
} seg;
void modify(int u, int w) {
g[u][1]+=w-a[u]; a[u]=w;
while (u) {
Matrix A=seg.query(1,1,n,dfn[top[u]],dfn[deepest[top[u]]]);
p[u].a[0][0]=p[u].a[1][0]=g[u][0]; p[u].a[0][1]=g[u][1]; p[u].a[1][1]=inf;
seg.update(1,1,n,dfn[u]);
Matrix B=seg.query(1,1,n,dfn[top[u]],dfn[deepest[top[u]]]);
int v=top[u];
g[fa[v]][0]-=max(A.a[0][0],A.a[0][1]);
g[fa[v]][0]+=max(B.a[0][0],B.a[0][1]);
g[fa[v]][1]-=A.a[0][0];
g[fa[v]][1]+=B.a[0][0];
u = fa[top[u]];
}
}
signed main() {
read(n, m);
for (int i=1;i<=n;++i) read(a[i]);
for (int i=1,u,v;i<n;++i) {
read(u,v); G[u].eb(v); G[v].eb(u);
} dfs(1,0); dfs2(1,1); seg.build(1,1,n);
// (p[2]*p[5]).out();
for (int i=1;i<=m;++i) {
int x,y; read(x,y);
modify(x,y);
Matrix ret=seg.query(1,1,n,dfn[1],dfn[deepest[1]]);
write(max(ret.a[0][1],ret.a[0][0]));
}
return 0;
}