P3574

· · 个人记录

[POI2014]FAR-FarmCraft

我们可以先做树形 DP,虽然从复杂度上来讲是不对的,但从某种意义上讲,我们明白了这个答案和遍历子树的顺序有着很大的关系。

我们令 f_p 表示 p 子树内的最大值(只是相对于子树的),令 siz_p 表示遍历子树 p 需要花费的时间。

那么我们假设 vp 能到达的一个子节点,就有这样一个转移:

f_p=\max\{f_p,siz_p+f_v+1\}

这里的 siz_p 表示遍历 v 之前的 p 的子节点所花费的时间,之所以加一是算上从 pv 的一段距离(我们前面记下的 siz_p 花费的时间包括了回到 p)。

然后我们知道,p 的子节点不只一个,那么就设另一个子节点为 u。假设,我们先遍历 u 再遍历 v,那么有如下转移:

f_p=\max\{f_p,siz_p+\max\{f_u,f_v+siz_u+2\}+1\}

这里加 2 是 u\rightarrow p\rightarrow v 的距离。

那么如果说我们先遍历 v 再遍历 u 呢?

f_p=\max\{f_p,siz_p+\max\{f_v,f_u+siz_v+2\}+1\}

假如说,我们先遍历 u 是更优的,那么从这个局部看,我们就能推出:

siz_p+\max\{f_u,f_v+siz_u+2\}+1<siz_p+\max\{f_v,f_u+siz_v+2\}+1

即:

\max\{f_u,f_v+siz_u+2\}<\max\{f_v,f_u+siz_v+2\}

又因为我们知道下面的条件:

那么可以大力分类讨论:

  1. f_u<f_v+siz_u+2,则 \max\{f_u,f_v+siz_u+2\}=f_v+siz_u+2,也就有 f_v+siz_u+2<\max\{f_v,f_u+siz_v+2\};又因为 f_v+siz_u+2>f_v,那么只有可能 f_v+siz_u+2<f_u+siz_v+2,即 f_v-siz_v<f_u-siz_u

  2. f_u\ge f_v+siz_u+2,又因为 f_u+siz_v+2>f_u,所以 f_u+siz_v+2>f_u\ge f_v+siz_u+2,即 f_u+siz_v+2>f_v+siz_u+2,即 f_v-siz_v<f_u-siz_u

综上,满足 f_v-siz_v<f_u-siz_u,则先遍历 u 是更优的。(先遍历 u 更优 \Leftrightarrow f_v-siz_v<f_u-siz_u

换句话说,这个遍历的顺序是可以贪心求解的。我们可以将 p 的子节点的 fsiz 求出,直接按照 f-siz 从大到小进行排序,然后再根据顺序更新 p 的信息。

因为所有点都只经过一次排序,遍历的复杂度是线性的,所以说总的时间复杂度是 O(n\log n)

然后就是一些细节问题,我们针对每颗子树,子树的根节点一定是第一个被送到电脑的;然而节点 1(即整棵树的根节点)却不是,它是最后一个被送到的,因此要单独拿出来,与子树内部的点比较再得出答案。

代码:

#include<iostream>
#include<cstdio>
#include<algorithm>
#define ll long long
using namespace std;

const ll N=5e5;

ll n,u,v,tot;

ll c[N+5],f[N+5],siz[N+5],tmp[N+5];

ll ver[N*2+5],nxt[N*2+5],head[N+5];

bool cmp(ll x,ll y) {
    return f[x]-siz[x]>f[y]-siz[y];
}

void dfs(ll p,ll fath) {
    if(p!=1) f[p]=c[p];
    for(ll i=head[p];i;i=nxt[i]) {
        if(ver[i]==fath) continue;
        dfs(ver[i],p);
    }
    ll cnt=0;
    for(ll i=head[p];i;i=nxt[i]) {
        if(ver[i]==fath) continue;
        tmp[++cnt]=ver[i];
    }
    sort(tmp+1,tmp+cnt+1,cmp);
    for(ll i=1;i<=cnt;i++) {
        f[p]=max(f[p],siz[p]+f[tmp[i]]+1);
        siz[p]+=siz[tmp[i]]+2;
    }
}

void add(ll u,ll v) {
    ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;
}

inline ll read() {
    ll ret=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') f=-f;ch=getchar();}
    while(ch>='0'&&ch<='9') {ret=(ret<<3)+(ret<<1)+ch-'0';ch=getchar();}
    return ret*f;
}

void write(ll x) {
    static char buf[22];static ll len=-1;
    if(x>=0) {
        do{buf[++len]=x%10+48;x/=10;}while(x);
    }
    else {
        putchar('-');
        do{buf[++len]=-(x%10)+48;x/=10;}while(x);
    }
    while(len>=0) putchar(buf[len--]);
}

int main() {

    n=read();

    for(ll i=1;i<=n;i++) {
        c[i]=read();
    }

    for(ll i=1;i<n;i++) {
        u=read();v=read();
        add(u,v);add(v,u);
    }

    dfs(1,0);

    write(max(f[1],siz[1]+c[1]));

    return 0;
}