【学习·文化课】中考复习十四:树上换根背包(qoj4815 sol 2;CF2229I)

· · 算法·理论

【树上换根背包:qoj4815 sol 2;CF2229I】前置知识:换根 DP;树上背包及其复杂度分析。

这个东西太有趣了,学习。

理论部分

假设我们要求在一棵树上恰好选 K 个点,然后求一定的代价,对于每个点为根都要这么算。

对于单个根,先以下面的树上背包为例。

void dfs(int u,int p){
  sz[u]=1;
  for (auto v:G[u])if (v^p){
    dfs(v,u);
    memset(tmp,0,sizeof(tmp));
    rep(i,0,min(m,sz[u]))rep(j,0,min(m-i,sz[v]))
      tmp[i+j]=max(tmp[i+j],f[u][i]+f[v][j]);
    memcpy(f[u],tmp,sizeof(f[u]));
    sz[u]+=sz[v];
  }
  // Do something to f
}

首先,我们知道,像上面一样对一棵 n 点有根树进行值域为 k 的树上背包复杂度为 \Theta(nk)。因而我们换根背包的理论复杂度下界就是 \Theta(nk)

能不能做到这个复杂度下界呢?我们先考虑想一点有拓展性的暴力。考虑像所有换根一样,多定义一个 g(u,i) 表示在 u 的子树外面(不包括 u)选了 i 个点的贡献,那么最后每个根的答案就是 \max\{f(u,x)+g(u,k-x)\}。考虑转移,你从 g(u,*) 转移到一个儿子 g(v,*) 就是要添加所有 u 的其它儿子的 f,这可以维护前后缀的 f(u,*) 实现。

那么现在的问题就是下面三个背包数组的合并:u 的一段前缀儿子的 f 的合并的背包 l(*)u 的一段后缀儿子的 f 的合并的背包 r(*),和 g(u,*)。这样子直接合并就成 \Theta(nk^2) 的了,我们考虑用一点手法来优化。

注意到 g(v,x)x+sz_v<k 的时候是没用的,可以联系对根答案的贡献观察,发现这个的意思就是子树内外一共选不到 k 个点。因此假设我们要合并 lg(u)(没有 r),我们就能做到 \Theta(nk) 了。具体步骤是枚举最后 g(v,x),l(y)xy,这样由于 x 枚举量只有 sz_v(从 k-sz_vk),总枚举量就是和普通树上背包一样的 \Theta(nk)

那么多一个 r 怎么办?不妨考虑一个大胆的想法,直接开头令 r=g(u),然后这么干。这样子合并 l,r 时确实 \Theta(nk) 了,然而怎么每次合并 r 和一个子树 wf(w,*) 呢?那么我们发现,类似上面的推理,由于 l(x) 的大小也是很小的,所以 r(x) 也是有一个 x 的枚举下界的,它在 x+p<k 的时候是没有用的(p 为在 w-1 时背包数组 l 的大小)。因此我们此时枚举量就也是 \Theta(nk) 的(也是子树两两大小乘积和)。

这样我们就成功地做到了 \Theta(nk)。注意这个东西它只能求单点背包的值(也就是说要 k 给定),否则显然上面的一大串优化全都失去了效果。

不难发现刚刚的所有讨论都没有用到 \max\minx\cdot x=x 的幂等律,因此可以拓展到计数的问题,和最优化是完全一样的。

典型例题

这里有两个比较板的例题,暴力的背包 DP 极为容易写出,因此我不加赘述直接展示出换根背包的代码。

::::info[code]

#pragma GCC optimize("Ofast","inline","fast-math","unroll-loops")
#include <bits/stdc++.h>
#define rep(i, a, b) for (int i = (a), i##ABRACADABRA = (b); i <= i##ABRACADABRA; i++)
#define drep(i, a, b) for (int i = (a), i##ABRACADABRA = (b); i >= i##ABRACADABRA; i--)
using namespace std;
constexpr int I=1e9;

int n,k,m,sz[40010],son[40010],psz[40010],a[40010],tmp[3010],f[40010][3010],g[40010][3010],pre[40010][3010],suf[40010][3010];
vector<int>G[40010];

void dfs1(int u,int p){
  sz[u]=1;
  vector<int>gg;
  for (auto v:G[u])if (v^p){
    gg.push_back(v);
    dfs1(v,u);
    memset(tmp,0xc2,sizeof(tmp));
    rep(i,0,min(k,sz[u]))rep(j,0,min(k-i,sz[v])){
      int wu=f[u][i],wv=!j?0:f[v][j-1]+a[v];
      tmp[i+j]=max(tmp[i+j],wu+wv);
    }
    memcpy(f[u],tmp,sizeof(f[u]));
    sz[u]+=sz[v];
  }
  G[u].swap(gg);
}
void dfs2(int u){
  m=0;
  for (auto v:G[u])son[++m]=v;
  rep(i,1,m)psz[i]=psz[i-1]+sz[son[i]];
  rep(i,0,m+1){
    rep(j,1,k+1)pre[i][j]=suf[i][j]=-I;
    pre[i][0]=suf[i][0]=0;
  }
  rep(i,0,k+1)suf[m+1][i]=g[u][i];
  // cout<<u<<": ";
  // rep(i,0,k)cout<<g[u][i]<<" \n"[i==k];
  rep(t,1,m){
    int v=son[t];
    // cout<<v<<' '<<a[v]<<'\n';
    rep(i,0,min(k,psz[t-1]))rep(j,0,min(k-i,sz[v])){
      int wu=pre[t-1][i],wv=!j?0:f[v][j-1]+a[v];
      pre[t][i+j]=max(pre[t][i+j],wu+wv);
    }
  }
  drep(t,m,1){
    int v=son[t];
    // cout<<v<<' '<<a[v]<<'\n';
    rep(i,0,k)rep(j,max(0,k-psz[t-1]-i),min(k-i,sz[v])){
      int wu=suf[t+1][i],wv=!j?0:f[v][j-1]+a[v];
      suf[t][i+j]=max(suf[t][i+j],wu+wv);
    }
  }
  // 更新 g
  rep(t,1,m){
    int v=son[t];
    rep(i,0,min(k,psz[t-1]))rep(j,max(0,k-i-sz[son[t]]),k-i)
      g[v][i+j]=max(g[v][i+j],pre[t-1][i]+suf[t+1][j]);
    // cout<<"TMP "<<v<<": ";
    // rep(i,0,k)cout<<g[v][i]<<" \n"[i==k];
  }
  for (auto v:G[u]){
    drep(i,min(k,n-sz[v]),1)g[v][i]=g[v][i-1]+a[u];
    dfs2(v);
  }
}

void solve(){
  scanf("%d%d",&n,&k),--k;
  rep(i,1,n)scanf("%d",&a[i]);
  rep(i,0,n+1){
    G[i]={};
    rep(j,1,k+1)f[i][j]=g[i][j]=-I;
  }
  rep(i,1,n-1){
    int u,v;
    // u=i,v=i+1;
    scanf("%d%d",&u,&v);
    G[u].push_back(v);
    G[v].push_back(u);
  }
  dfs1(1,0),dfs2(1);
  rep(i,1,n){
    int ans=0;
    rep(j,0,k)ans=max(ans,f[i][j]+g[i][k-j]);
    printf("%d%c",ans+a[i]," \n"[i==n]);
  }
}

int main() {
  int tt=1;
  // scanf("%d",&tt);
  while (tt--)solve();
  return 0;
}

::::

::::info[code]

#include <bits/stdc++.h>
#define rep(i, a, b) for (int i = (a), i##ABRACADABRA = (b); i <= i##ABRACADABRA; i++)
#define drep(i, a, b) for (int i = (a), i##ABRACADABRA = (b); i >= i##ABRACADABRA; i--)
using namespace std;
using ll = long long;

int n,k,m,sz[4010],son[4010],psz[4010];
ll a[4010],tmp[4010],f[4010][4010],g[4010][4010],pre[4010][4010],suf[4010][4010];
vector<int>G[4010];

void dfs1(int u,int p){
  sz[u]=1;
  vector<int>gg;
  for (auto v:G[u])if (v^p){
    gg.push_back(v);
    dfs1(v,u);
    memset(tmp,0,sizeof(tmp));
    rep(i,0,min(k,sz[u]))rep(j,0,min(k-i,sz[v])){
      ll wu=f[u][i],wv=!j?f[v][j]:max(f[v][j],f[v][j-1]+j*a[v]);
      tmp[i+j]=max(tmp[i+j],wu+wv);
    }
    memcpy(f[u],tmp,sizeof(f[u]));
    sz[u]+=sz[v];
  }
  G[u].swap(gg);
}
void dfs2(int u){
  m=0;
  for (auto v:G[u])son[++m]=v;
  rep(i,1,m)psz[i]=psz[i-1]+sz[son[i]];
  rep(i,0,m+1)rep(j,0,k+1)pre[i][j]=suf[i][j]=0;
  rep(i,0,k+1)suf[m+1][i]=g[u][i];
  // cout<<u<<": ";
  // rep(i,0,k)cout<<g[u][i]<<" \n"[i==k];
  // 这里把祖先 DP 出的东西给了 suf。此时 pre 直接暴力转移即可,而 suf 要考虑的事情可就多了。
  rep(t,1,m){
    int v=son[t];
    // cout<<v<<' '<<a[v]<<'\n';
    rep(i,0,min(k,psz[t-1]))rep(j,0,min(k-i,sz[v])){
      ll wu=pre[t-1][i],wv=!j?f[v][j]:max(f[v][j],f[v][j-1]+j*a[v]);
      pre[t][i+j]=max(pre[t][i+j],wu+wv);
    }
  }
  drep(t,m,1){
    int v=son[t];
    // cout<<v<<' '<<a[v]<<'\n';
    rep(i,0,k)rep(j,max(0,k-psz[t-1]-i),min(k-i,sz[v])){
      // 这里 suf 不能使用 g[u] 相关的大小,但是发现 suf 和 pre 马上贡献给 g 的时候是 >=k-sz[t] 的,
      // 而 pre[t-1] 的范围 <=psz[t-1],那么显然 suf[t+1] 的范围 >=k-psz[t],也就是说处理 suf[t] 的时候只要枚举 k-psz[t-1]...k,这样子就复杂度对了
      ll wu=suf[t+1][i],wv=!j?f[v][j]:max(f[v][j],f[v][j-1]+j*a[v]);
      suf[t][i+j]=max(suf[t][i+j],wu+wv);
    }
  }
  // 更新 g
  rep(t,1,m){
    int v=son[t];
    rep(i,0,min(k,psz[t-1]))rep(j,max(0,k-i-sz[son[t]]),k-i)
      g[v][i+j]=max(g[v][i+j],pre[t-1][i]+suf[t+1][j]);
    // cout<<"TMP "<<v<<": ";
    // rep(i,0,k)cout<<g[v][i]<<" \n"[i==k];
  }
  for (auto v:G[u]){
    drep(i,min(k,n-sz[v]),1)g[v][i]=max(g[v][i],g[v][i-1]+i*a[u]);
    dfs2(v);
  }
}

void solve(){
  scanf("%d%d",&n,&k),--k;
  // n=4000,k=n-1;
  // rep(i,1,n)a[i]=i;
  rep(i,1,n)scanf("%lld",&a[i]);
  rep(i,0,n+1){
    G[i]={};
    rep(j,0,n+1)f[i][j]=g[i][j]=0;
  }
  rep(i,1,n-1){
    int u,v;
    // u=i,v=i+1;
    scanf("%d%d",&u,&v);
    G[u].push_back(v);
    G[v].push_back(u);
  }
  dfs1(1,0),dfs2(1);
  rep(i,1,n){
    ll ans=0;
    rep(j,0,k)ans=max(ans,f[i][j]+g[i][k-j]);
    printf("%lld%c",ans+(k+1)*a[i]," \n"[i==n]);
  }
}

int main() {
  int tt;
  scanf("%d",&tt);
  while (tt--)solve();
  return 0;
}

::::