P2680

· · 个人记录

[NOIP2015 提高组] 运输计划

关于各种手误。。。

午睡刚醒还是不太适合做题罢。。。

那么二分答案,挑选 >mid 的路径,将 cnt 加一,并将路径上的所有边权值加一,这个可以用差分实现。最后看是否有权值为 cnt 的边,并且看将这条边的 t 赋 0 后是否是原来的最长路径的长度 \le mid

反正是有点拐的二分答案,要想出来还要动点脑子。。

时间复杂度 O((m+n)\log \sum t)

代码:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#include<algorithm>
#include<cmath>
#define ll long long
using namespace std;

const ll N=3e5;

ll n,m,u,v,w,r,l,mid,tot,flg,ma,cnt;

ll fa[N+5][25],len[N+5],lg[N+5],ver[N*2+5],wt[N*2+5],nxt[N*2+5],head[N+5],c[N+5],d[N+5],dt[N+5],dis[N+5],lca_[N+5],a[N+5],b[N+5];

void _dfs(ll p,ll fath) {
    c[p]=d[p];
    for(ll i=head[p];i;i=nxt[i]) {
        if(ver[i]==fath) continue;
        _dfs(ver[i],p);
        if(flg==1) return;
        c[p]=c[p]+c[ver[i]];
    }
    if(c[p]>=cnt&&p!=1) {
        if(dt[p]-dt[fath]>=ma) {
            flg=1;
        }
    }
}

bool check(ll mid) {
    cnt=0;memset(d,0,sizeof(d));
    memset(c,0,sizeof(c));
    ma=0;flg=0;
    for(ll i=1;i<=m;i++) {
        if(len[i]>mid) {
            d[a[i]]++;d[b[i]]++;d[lca_[i]]-=2;
            cnt++;ma=max(ma,len[i]-mid);
        }
    }
    _dfs(1,0);
    if(cnt==0) return 1;
    if(flg==1) return 1;
    return 0;
}

void dfs(ll p,ll fath) {
    fa[p][0]=fath;dis[p]=dis[fath]+1;
    for(ll i=1;i<=lg[dis[p]];i++) {
        fa[p][i]=fa[fa[p][i-1]][i-1];
    }
    for(ll i=head[p];i;i=nxt[i]) {
        if(ver[i]==fath) continue;
        dt[ver[i]]=dt[p]+wt[i];dfs(ver[i],p);
    }
}

ll lca(ll a,ll b) {
    if(dis[a]<dis[b]) swap(a,b);
    while(dis[a]>dis[b]) a=fa[a][lg[dis[a]-dis[b]]-1];
    if(a==b) return a;
    for(ll k=lg[dis[a]]-1;k>=0;k--) {
        if(fa[a][k]!=fa[b][k]) {
            a=fa[a][k];b=fa[b][k];
        }
    }
    return fa[a][0];
}

void add(ll u,ll v,ll w) {
    ver[++tot]=v;wt[tot]=w;nxt[tot]=head[u];head[u]=tot;
}

inline ll read() {
    ll ret=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') f=-f;ch=getchar();}
    while(ch>='0'&&ch<='9') {ret=(ret<<3)+(ret<<1)+ch-'0';ch=getchar();}
    return ret*f;
}

void write(ll x) {
    if(x<0) {x=-x;putchar('-');}
    if(x>9) write(x/10);
    putchar(x%10+48);
}

int main() {

    n=read();m=read();

    for(ll i=1;i<n;i++) {
        u=read();v=read();w=read();
        r=r+w;
        add(u,v,w);add(v,u,w);
    }

    for(ll i=1;i<=n;i++) lg[i]=lg[i-1]+(1<<lg[i-1]==i);

    dfs(1,0);

    for(ll i=1;i<=m;i++) {
        a[i]=read();b[i]=read();
        lca_[i]=lca(a[i],b[i]);
        len[i]=dt[a[i]]-dt[lca_[i]]+dt[b[i]]-dt[lca_[i]];
    }

    while(l<r) {
        mid=(l+r)/2;
        if(check(mid)) r=mid;
        else l=mid+1;
    }

    write(l);

    return 0;
}