P1552

· · 个人记录

[APIO2012]派遣

我们要求对于每颗子树,取其中部分节点的薪水和 \le m,且使这个节点数尽量大,接着比较即可。

那么可以想象,对于每颗子树,我们只取薪水要求最少的几个节点,就可以找到派遣数量的最大值。

如果直接 dfs 后暴力排序,显然是会炸的。如果排序改为归并,大概可以过 30% 的数据。

我们可以发现,在归并的过程中,很多点是没用的。

比如说我们已经确定了某一颗子树的最佳人员派遣方案,这颗子树中未被选入派遣名单的成员在之后的合并中亦不可能入选名单,因为它加入之后一定会使答案更劣。

那么我们干脆把不可能入选后续名单的人直接删掉,达到优化的目的。

于是方向就很明确了,使用左偏树(大根堆)进行合并,然后在每次合并之后把较大薪水的几个点从堆中删除,直到剩下的人全部派遣的薪水 \le m,这需要我们在删除的过程中同时维护成员数和薪水总额两个信息。

因为每个点至多被删除一次,每颗子树至多被合并一次,每条边连接的两点寻找堆顶各一次。因此总的时间复杂度是 O(n\log n) 的。

代码:

#include<iostream>
#include<cstdio>
#include<algorithm>
#define ll long long
using namespace std;

const ll N=1e5;

ll n,m,rt,x,ans,tot;

ll siz[N+5],sum[N+5],c[N+5],l[N+5];

ll ver[N+5],nxt[N+5],head[N+5];

struct lh{
    ll dis,val,rs,ls,rt;
}s[N+5];

ll merge(ll x,ll y) {
    if(!x||!y) return x+y;
    if(s[x].val<s[y].val||(s[x].val==s[y].val&&x>y)) swap(x,y);
    s[x].rs=merge(s[x].rs,y);
    if(s[s[x].ls].dis<s[s[x].rs].dis) swap(s[x].ls,s[x].rs);
    s[s[x].ls].rt=s[s[x].rs].rt=s[x].rt=x;
    s[x].dis=s[s[x].rs].dis+1;return x;
}

ll get(ll x) {
    return s[x].rt==x?x:s[x].rt=get(s[x].rt);
}

void pop(ll x) {
    s[x].val=-1;
    s[s[x].ls].rt=s[x].ls;s[s[x].rs].rt=s[x].rs;
    s[x].rt=merge(s[x].ls,s[x].rs);
}

void dfs(ll p) {
    s[p].rt=p;s[p].val=c[p];
    siz[p]=1;sum[p]=c[p];
    for(ll i=head[p];i;i=nxt[i]) {
        dfs(ver[i]);
        siz[p]+=siz[ver[i]];
        sum[p]+=sum[ver[i]];
        ll f1=get(p),f2=get(ver[i]);
        if(f1!=f2) s[f1].rt=s[f2].rt=merge(f1,f2);
        while(sum[p]>m) {
            sum[p]-=s[get(p)].val;siz[p]--;
            pop(get(p));
        }
    }
    ll tmp=siz[p]*l[p];
    if(tmp>ans) ans=tmp;
}

void add(ll u,ll v) {
    ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;
}

inline ll read() {
    ll ret=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') f=-f;ch=getchar();}
    while(ch>='0'&&ch<='9') {ret=(ret<<3)+(ret<<1)+ch-'0';ch=getchar();}
    return ret*f;
}

void write(ll x) {
    static char buf[22];static ll len=-1;
    if(x>=0) {
        do{buf[++len]=x%10+48;x/=10;}while(x);
    }
    else {
        putchar('-');
        do{buf[++len]=-(x%10)+48;x/=10;}while(x);
    }
    while(len>=0) putchar(buf[len--]);
}

int main() {

    n=read();m=read();

    for(ll i=1;i<=n;i++) {
        x=read();
        if(x) add(x,i);
        else rt=i;
        c[i]=read();l[i]=read();
    }

    dfs(rt);

    write(ans);

    return 0;
}