题解:P14622 [2019 KAIST RUN Fall] Wind of Change

· · 题解

点分治板子。

双树相应点路径问题,一般解法是对于一棵树分治,维护另一棵树的数据结构。

对于 T_1 树分治,建出当前分治层所有点在 T_2 上的虚树,记分治中心为 rt,则每个点 u 的加权为 dis(rt,u),扫描虚树一遍求一下加权的最近点即可。

#include "bits/stdc++.h"
using namespace std;
typedef long long ll1;
#define pii pair<int,int>
#define pil pair<int,ll1>
#define mkp make_pair
#define fir first
#define sec second
const int N=1e6+5,M=4e6+5;
const ll1 inf=1e17;
vector<pil>e[N],e1[N];
int n;
ll1 dep[N];
namespace LCA{
    int dfn[N];
    int Log2[N],fa[N],d[N],dfns[19][N],tdfn;
    int cmpr(int x,int y){return d[x]<d[y]?x:y;}
    void predfs(int u,int Fa,int D,ll1 Dep){
        d[u]=D,dep[u]=Dep,fa[u]=Fa;
        dfn[u]=++tdfn,dfns[0][tdfn]=u;
        for(auto [v,w]:e[u])if(v!=Fa)predfs(v,u,D+1,Dep+w);
    }
    void build(){
        for(int i=2;i<=n;i++){
            Log2[i]=Log2[i-1];
            if(i==(1<<(Log2[i]+1)))Log2[i]++;
        }
        predfs(n+1,0,0,0);
        for(int k=1;(1<<k)<=n;k++)for(int i=1;i+(1<<k)-1<=n;i++)
            dfns[k][i]=cmpr(dfns[k-1][i],dfns[k-1][i+(1<<(k-1))]);
    }
    int Lca(int x,int y){
        if(x==y)return x;
        x=dfn[x],y=dfn[y];if(x>y)swap(x,y);
        x++;int k=Log2[y-x+1];
        return fa[cmpr(dfns[k][x],dfns[k][y-(1<<k)+1])];
    }
}
bool cmpdfn(int x,int y){return LCA::dfn[x]<LCA::dfn[y];}
ll1 Ans[N],dis[N];
int dfn[N],mxd[N],dfns[N],p[N],totp,mns,root,tdfn;
bool vis[N],inp[N];

void ae1(int u,int v){
    e1[u].emplace_back(mkp(v,dep[v]-dep[u]));
}
void dfs1(int u){
    dfn[u]=++tdfn,dfns[tdfn]=u;
    for(auto [v,w]:e1[u])dfs1(v);
    mxd[u]=tdfn;
}
#define ls (o<<1)
#define rs (ls|1)
#define mid ((l+r)>>1)
namespace sgt{
    int L[M],R[M];ll1 mn[M],ad[M];
    void pup(int o){mn[o]=min(mn[ls],mn[rs]);}
    void build(int o,int l,int r){
        L[o]=l,R[o]=r;ad[o]=0;
        if(l==r){mn[o]=dep[dfns[l]]+dis[dfns[l]];return;}
        build(ls,l,mid),build(rs,mid+1,r);
        pup(o);
    }
    void Add(int o,ll1 x){mn[o]+=x,ad[o]+=x;}
    void psd(int o){if(ad[o])Add(ls,ad[o]),Add(rs,ad[o]),ad[o]=0;}
    void add(int o,int lt,int rt,ll1 x){
        if(lt>rt)return;
        int l=L[o],r=R[o];
        if(l>=lt&&r<=rt)return Add(o,x);
        psd(o);
        if(lt<=mid)add(ls,lt,rt,x);
        if(rt> mid)add(rs,lt,rt,x);
        pup(o);
    }
    ll1 qrymn(int o,int lt,int rt){
        if(lt>rt)return inf;
        int l=L[o],r=R[o];
        if(l>=lt&&r<=rt)return mn[o];
        psd(o);ll1 res=inf;
        if(lt<=mid)res=min(res,qrymn(ls,lt,rt)); 
        if(rt> mid)res=min(res,qrymn(rs,lt,rt));
        return res;
    }
}
#undef ls
#undef rs
#undef mid
void dfs2(int u){
    if(inp[u])
        Ans[u-n]=min(Ans[u-n],min(sgt::qrymn(1,1,dfn[u]-1),sgt::qrymn(1,dfn[u]+1,tdfn))+dis[u]);
    for(auto [v,w]:e1[u]){
        sgt::add(1,1,tdfn,w);
        sgt::add(1,dfn[v],mxd[v],-2ll*w);
        dfs2(v);
        sgt::add(1,1,tdfn,-w);
        sgt::add(1,dfn[v],mxd[v],2ll*w);
    }
}
void Calc(){
    for(int i=1;i<=totp;i++)inp[p[i]]=1;
    sort(p+1,p+1+totp,cmpdfn);
    for(int i=totp;i>1;i--)p[++totp]=LCA::Lca(p[i],p[i-1]);
    p[++totp]=n+1;
    sort(p+1,p+1+totp,cmpdfn);
    totp=unique(p+1,p+1+totp)-p-1;
    for(int i=1;i<=totp;i++)
        e1[p[i]].clear();
    for(int i=1;i<=totp;i++)
        if(!inp[p[i]])dis[p[i]]=inf;
    for(int i=2;i<=totp;i++)
        ae1(LCA::Lca(p[i],p[i-1]),p[i]);
    tdfn=0;dfs1(n+1);
    sgt::build(1,1,tdfn);
    dfs2(n+1);
    for(int i=1;i<=totp;i++)inp[p[i]]=0;
}

void counts(int u,int Fa){
    ++totp;
    for(auto [v,w]:e[u])if(v!=Fa&&!vis[v])
        counts(v,u);
}
int fdrt(int u,int Fa){
    int sz=1,tsz=0,mxsz=0;
    for(auto [v,w]:e[u])if(v!=Fa&&!vis[v]){
        tsz=fdrt(v,u);
        sz+=tsz,mxsz=max(mxsz,tsz);
    }
    mxsz=max(mxsz,totp-sz);
    if(mxsz<=mns)root=u,mns=mxsz;
    return sz;
}
void getdis(int u,int Fa,ll1 D){
    dis[u+n]=D,p[++totp]=u+n;
    for(auto [v,w]:e[u])if(v!=Fa&&!vis[v])
        getdis(v,u,D+w);
}
void solve(int u){
    totp=0,mns=N,root=0;
    counts(u,0),fdrt(u,0);
    u=root;
    totp=0;getdis(u,0,0);
    Calc();
    vis[u]=1;
    for(auto [v,w]:e[u])if(!vis[v])solve(v);
}
void work(){
    for(int i=1;i<=n;i++)Ans[i]=inf;
    solve(1);
    for(int i=1;i<=n;i++)
        printf("%lld\n",Ans[i]);
}
int main(){
    scanf("%d",&n);
    for(int i=1,u,v,w;i<n;i++){
        scanf("%d%d%d",&u,&v,&w);
        e[u].emplace_back(mkp(v,w));
        e[v].emplace_back(mkp(u,w));
    }
    for(int i=1,u,v,w;i<n;i++){
        scanf("%d%d%d",&u,&v,&w),u+=n,v+=n;
        e[u].emplace_back(mkp(v,w));
        e[v].emplace_back(mkp(u,w));
    }
    LCA::build();
    work();
    return 0;
}