题解 P3023 【[USACO11OPEN]焊接Soldering】

· · 题解

英文题解(最下有)太长看不下去就去请教了一下大神,然后优化了一下O(N²)的算法就水过了。

状态f[i][j]表示,以i为根的子树,唯一一根接到i并且会在后面继续往上接的电线的当前长度为j的情况下的最小代价

转移有三种(蓝色为当前加入边)

1.蓝色电线截断在i

2.蓝色电线取代之前的电线接到子树外,红色电线截断在i

3.蓝色电线截断在i,并与红色电线合并成一条电线,另外绿色电线不一定如图连到子树外

【优化】

用队列维护f(建议用vector),维护信息j和f[i][j],对于每个深度只需要维护一个队列,每当一个孩子子树信息更新了队列信息后,删去队列内的两种无效点,一是对于相同的j只留下一个最优点,二是对于jx>jy并且fx[i][j]>fy[i][j],把x删除。PS:正解是找到某些性质然后用凸包优化树上dp

【呆马】

#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<iostream>
#include<vector>
#define ll long long
using namespace std;
const int N=50001;
struct st
{
    int l; ll co;
    st(){}
    st(int x,ll y){l=x; co=y;}
    ll cal(){return co+l*l;}
    bool operator <(const st x) const{return l<x.l || (l==x.l && co<x.co);}
};
vector <st> f[N],g;
vector <int> l[N];
int n,i,x,y;
bool vis[N];
void dp(int x,int d)
{
    vis[x]=1;
    f[d].clear();
    bool leaf=1;
    int m1=l[x].size(),m2,m3;
    for (int k=0;k<m1;k++)
        if (!vis[l[x][k]])
        {
            leaf=0;
            dp(l[x][k],d+1);
            m3=f[d+1].size();
            for (int i=0;i<m3;i++) f[d+1][i].l++;
            if (!f[d].size()){f[d]=f[d+1]; continue;}
            g.clear();
            m2=f[d].size();
            for (int i=0;i<m2;i++)
                for (int j=0;j<m3;j++)
                    if (!f[d][i].l) g.push_back(st(0,f[d][i].co+f[d+1][j].cal()));
                    else
                    {
                        g.push_back(st(f[d][i].l,f[d][i].co+f[d+1][j].cal()));
                        g.push_back(st(f[d+1][j].l,f[d+1][j].co+f[d][i].cal()));
                        g.push_back(st(0,f[d][i].cal()+f[d+1][j].cal()+((f[d][i].l*f[d+1][j].l)<<1)));
                    }
            sort(g.begin(),g.end());
            int num=0;
            m2=g.size();
            for (int i=1;i<m2;i++)
                if (g[i].cal()<g[num].cal()) g[++num]=g[i];
            g.resize(num+1);
            f[d]=g;
        }
    if (leaf) f[d].push_back(st(0,0));
}
int main()
{
        scanf("%d\n",&n);
        for (i=1;i<n;i++)
        {
            scanf("%d%d\n",&x,&y);
            l[x].push_back(y);
            l[y].push_back(x);
        }
        dp(1,0);
        printf("%lld",f[0][0].cal());
}