【题解】P1131 [ZJOI2007]时态同步

· · 个人记录

看到大家的题解都是在回溯时计算并统计答案,这里提供一种不一样的思路:在自上向下dfs时贪心。

我们的目标状态是让所有叶子节点到根节点距离相等,故这题可以理解为贪心,因为边权只能加不能减,所以最后的距离一定是初始状态所有叶子节点到根节点距离的最大值,把这个值设为 max

正如前面神仙的题解所说,一定是尽量往靠近根节点的边上添加边权更优,因为一次加边权可以缩小更多叶子节点到根节点距离与 max 的差。

因此我们的决策就确定了:我们可以维护每个子树内叶节点到根(这里指的不是到子树的根,而是到整棵树的根 S 的最大距离,将它设为 mx[u] 。 设目前枚举到节点 u ,只要 mx[u] < max ,那么就贪心地 把u 到它父亲 fa 的这条边的边权加上 max-mx[u] ,以使得 mx[u] = max ,如此不断正序递归下去,每次答案都增加 max-mx[u] ,就能统计出最后的答案。

当我们加上 ufa 的边权时,子树内所有的 mx[son] 都会加上 max-mx[u] 但我们在递归下至子节点时并不知道祖先加了多少,如果每次操作都去子树内暴力修改 mx 显然会超时,所以我们可以在dfs时维护一个 tag ,每次把 mx[u] 加上 tag ,表示由祖先的修改使得它的 mx 值多了 tag ,这样我们增加的边权即为 max-(mx[u]+tag) 。每次修改后再将 tag 加上修改过的max-(mx[u]+tag) 后向下递归,这样我们就得出了最优解。

#include<bits/stdc++.h>
using namespace std;
#define int long long

inline void read(int &num)
{
    num = 0;
    int f=1;
    char c=getchar();
    while(!isdigit(c))
    {
        if(c=='-')
            f=-1;
        c=getchar();
    }
    while(isdigit(c))
    {
        num=num*10+c-48;
        c=getchar();
    }
    num*=f;
}

const int maxn = 1e6+7;

int n,tot,rt,x,y,z,mmx,ans;
int hed[maxn],to[maxn],nxt[maxn],edg[maxn],cst[maxn],mx[maxn];

void add(int x,int y,int z)
{
    to[++tot]=y;
    edg[tot]=z;
    nxt[tot]=hed[x];
    hed[x]=tot;
}

void gtcst(int x,int fa)
{
    mmx=max(mmx,cst[x]);//求出最远距离 
    for(int i=hed[x];i;i=nxt[i])
    {
        int y=to[i],z=edg[i];
        if(y==fa)
            continue;
        cst[y]=cst[x]+z;//cst表示每个节点到根节点的距离 
        gtcst(y,x);
    }
}

void gtmx(int x,int fa)
{
    mx[x]=cst[x];
    for(int i=hed[x];i;i=nxt[i])
    {
        int y=to[i];
        if(y==fa)
            continue;
        gtmx(y,x);
        mx[x]=max(mx[x],mx[y]);//每个子树内的最远距离 
    }
}

void dfs(int x,int fa,int tag)
{
    for(int i=hed[x];i;i=nxt[i])
    {
        int y=to[i],z=edg[i];
        int cha=0;
        if(y==fa)
            continue;
        if(mx[y]+tag!=mmx)//这里就是关键的部分: 将前面修改的贡献累加到当前的mx上 
        {
            cha=mmx-mx[y]-tag;
            ans+=cha;//统计答案 
        }
        dfs(y,x,tag+cha);//tag加上修改的权值,表示下面子树的所有节点的mx都要加上新的tag值 
    }
}

signed main()
{
    read(n);
    read(rt);
    for(int i=1;i<n;++i)
        read(x),read(y),read(z),add(x,y,z),add(y,x,z);
    gtcst(rt,0);
    gtmx(rt,0);
    dfs(rt,0,0);
    printf("%lld\n",ans);
    return 0;
}

自以为讲得还是比较透彻的,如果有不清楚的可以私信问我,随时乐意回答