题解 P1273 【有线电视网】
写在前面
树形DP一般就是
f[i][j] 表示以第i 个结点为根的子树中,j 个儿子或者价值为j ,在这道题里面,j 表示在这个子树中选取j 个叶子节点。
在AC这道题之前,看过无数题解,这句话是笔者认为写的最精彩的。
状态
考虑一钟状态设计方式
方程
就是说,选择一部分的
思考,如果让儿子结点
为什么是分组背包
分组背包指在很多组中只能选一种物品,获得的最大利润
回到本题,组就是每一棵子树,物品其实并不是叶子结点,而是叶子节点的满足个数。
比如,在以结点
- 满足以结点
u 为根的子树中,1 个叶子结点的需求。 - 满足以结点
u 为根的子树中,2 个叶子结点的需求。 - 满足以结点
u 为根的子树中,3 个叶子结点的需求。 - 满足以结点
u 为根的子树中,4 个叶子结点的需求。 - 满足以结点
u 为根的子树中,Size_v 个叶子结点的需求。
每一组的各个元素都是互相矛盾只能选一种的。
这就是分组背包。
对于每一个结点都要进行分组背包。
对于扩展结点时,假如当前结点时
应该使用
int &sum = lef[now];//表示子树大小
for(_R int i = head[now];i;i = edge[i].nxt) {
int exNode = edge[i].node;
int t = dfs(exNode); sum += t;
for(_R int j = lef[now];j >= 0;j--) {
for(_R int k = 0;k <= lef[exNode];k++) {
if(j - k < 0) continue;
dp[now][j] = max(dp[now][j], dp[now][j - k] + dp[exNode][k] - edge[i].w);
}
关于for(int j = lef[now];j >= 0;j--)
倒序枚举的原因同0-1背包滚动数组优化后的方法。
防止多次转移,也就是防止在同一组里选重复的元素。
即,
保证每一次
最后的最后
题目没有问能获得的最大利润,而是希望更多的用户能够用上电视,于是只要电视台不亏本,就可以满足。
从打到小枚举
Codes
#include<cstdio>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
#define _R register
#define inf 0x7fffffff
using namespace std;
const int _ = 3100;
inline int read()
{
char c = getchar(); int sign = 1; int x = 0;
while(c > '9' || c < '0') { if(c=='-')sign = -1; c = getchar(); }
while(c <= '9' && c >= '0') { x *= 10; x += c - '0'; c = getchar(); }
return x * sign;
}
int n, m;
int nodeVal[_];
int dp[_][_];
int head[_];
struct edge{
int node;
int w;
int nxt;
}edge[_];
int tot = 0;
void add(int u, int v, int w){
edge[++tot].nxt = head[u];
head[u] = tot;
edge[tot].node = v;
edge[tot].w = w;
}
int lef[_];
int dfs(int now)
{
if(now >= n - m + 1) return dp[now][1] = nodeVal[now], lef[now] = 1;
int &sum = lef[now];
for(_R int i = head[now];i;i = edge[i].nxt) {
int exNode = edge[i].node;
int t = dfs(exNode); sum += t;
for(_R int j = lef[now];j >= 0;j--) {
for(_R int k = 0;k <= lef[exNode];k++) {
if(j - k < 0) continue;
dp[now][j] = max(dp[now][j], dp[now][j - k] + dp[exNode][k] - edge[i].w);
}
}
}
return lef[now];
}
int main()
{
n = read(), m = read();
for(_R int i = 1;i <= n - m;i++){
int k = read();
for(_R int j = 1;j <= k;j++){
int A = read();
int C = read();
add(i, A, C);
}
}
for(_R int i = n - m + 1;i <= n;i++){
nodeVal[i] = read();
}
memset(dp, -100, sizeof(dp));
for(_R int i = 1;i <= n;i++) dp[i][0] = 0;
dfs(1);
for(_R int i = m;i >= 0;i--)
if(dp[1][i] >= 0){
return printf("%d\n", i), 0;
}
return 0;
}