题解:AT_icpc2013spring_e 最小生成树

· · 题解

双倍经验

link1(gesp 9月 T2) 题库更新后会加入。

link2(Atcoder)

前言(废话

本题解相对于其他另外两篇的题解更适合蒟蒻食用,有清楚的图解和 AC 代码示例。

本蒟蒻在 gesp 考试时遇到本题,因为没想到会考树剖导致写挂痛失近 500 rmb 。

(引用枚金牌大佬名言“不在考纲里就不算超纲”警示各位 oier 好好练题)。

思路

题目给定一个图,并且要求出删去某一条边后,图中的最小生成树。

有一种很 O(n^2) 的暴力做法,每次标记一条边不取,去求图中最小生成树。显然过不掉这题。

我们考虑优化每次删掉一条边如何快速求出一条“替代边”。我们先求出不删边的最小生成树,显然,对于那些不在最小生成树中的边,删掉也不会有什么影响。

那我们考虑对于不在树上的边该如何处理。用下面这三张图举例子:

这颗树是我们的最小生成树。

而这是一条不在树上的边(红色的)

对于这条“非树边”我们可以替换掉哪些边呢?显然是蓝色的边(如下)

由此我们知道了了在最小生成树上的边该如何解决。我们可以先求出原图中的最小生成树。并维护“非树边”的覆盖,我们可以通过树剖来维护每一条“树边”可替换的最短边。我们能在 O(\log^2 n) 的复杂度内添加一条“非树边”。然后对于每次替换我们可以在 O(\log n) 的复杂度内查询替换边。这样我们就做完这道题了。

坑点

  1. 树剖维护的是以这个点为子节点的边,在添加边时只能从 u,v 添加到 lca(u,v)的子节点。(意思是 lca(u,v) 并不会被添加)。
  2. 本蒟蒻认为树剖本身就具有许多坑点,需要多加小心。

AC代码

#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10;
int tag[4*N],mi[4*N];
int n;
void push_up(int op){
    mi[op]=min(mi[op*2],mi[op*2+1]);
    return;
}
void biuld(int l,int r,int op){
    tag[op]=2e9;
    if(l==r){
        mi[op]=2e9;
        return;
    }
    int mid=(l+r)/2;
    biuld(l,mid,op*2);
    biuld(mid+1,r,op*2+1);
    push_up(op);
    return;
} 
void add(int l,int r,int s,int t,int op,int k){
    if(s>=l&&t<=r){
        mi[op]=min(mi[op],k);
        tag[op]=min(tag[op],k);
        return;
    }
    int mid=(s+t)/2;
    if(tag[op]!=2e9){
        mi[op*2]=min(mi[op*2],tag[op]);
        mi[op*2+1]=min(mi[op*2+1],tag[op]);
        tag[op*2]=min(tag[op*2],tag[op]);
        tag[op*2+1]=min(tag[op*2+1],tag[op]);
        tag[op]=2e9;
    }
    if(l<=mid) add(l,r,s,mid,op*2,k);
    if(r>mid) add(l,r,mid+1,t,op*2+1,k);
    push_up(op);
    return;
}
int ask(int l,int r,int op,int p){
    if(l==r){
        return mi[op];
    }
    int mid=(l+r)/2;
    if(tag[op]!=2e9){
        mi[op*2]=min(mi[op*2],tag[op]);
        mi[op*2+1]=min(mi[op*2+1],tag[op]);
        tag[op*2]=min(tag[op*2],tag[op]);
        tag[op*2+1]=min(tag[op*2+1],tag[op]);
        tag[op]=2e9;
    }
    int res=2e9;
    if(p<=mid) res=min(res,ask(l,mid,op*2,p));
    else res=min(res,ask(mid+1,r,op*2+1,p));
    push_up(op);
    return res;
}
vector<int> g[N];
int hson[N],siz[N],top[N],dep[N],dfn[N],fa[N];
int new_ct;
int dfs1(int p,int f){
    fa[p]=f;
    dep[p]=dep[f]+1;
    siz[p]=1;
    int tmp=0;
    for(int i=0;i<g[p].size();i++){
        if(g[p][i]!=f){
            tmp=dfs1(g[p][i],p);
            siz[p]+=tmp;
            if(tmp>siz[hson[p]]){
                hson[p]=g[p][i];
            }
        }   
    } 
    return siz[p];
}
void dfs2(int p,int f,int tp){
    new_ct++;
    dfn[p]=new_ct;
    top[p]=tp;
    if(hson[p]) dfs2(hson[p],p,tp);
    for(int i=0;i<g[p].size();i++){
        if(g[p][i]!=f&&g[p][i]!=hson[p]){
            dfs2(g[p][i],p,g[p][i]);
        }
    }
    return;
}
void lca(int u,int v,int k){
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]]) swap(u,v);
        add(dfn[top[u]],dfn[u],1,n,1,k);
        u=fa[top[u]];
    }
    if(dep[u]<dep[v]) swap(u,v);
    add(dfn[v]+1,dfn[u],1,n,1,k);//不加lca;
    return; 
}
int fat[N];
int find(int a){
    if(fat[a]==a) return a;
    return fat[a]=find(fat[a]);
}
struct side{
    int u,v,w;
    int op;
};
side s[N];
bool xz[N];
bool cmp(side a,side b){
    return a.w<b.w;
}
int ans[N];
int main(){
    int m;
    cin>>n>>m;
    for(int i=1;i<=n;i++){
        fat[i]=i;
    }
    for(int i=1;i<=m;i++){
        cin>>s[i].u>>s[i].v>>s[i].w;
        s[i].op=i;
    } 
    sort(s+1,s+1+m,cmp);
    int u,v,w;
    int cnt=0;
    int tans=0;
    for(int i=1;i<=m;i++){
        u=s[i].u,v=s[i].v,w=s[i].w;
        int fu=find(u),fv=find(v);
        if(fu!=fv){
            fat[fu]=fat[fv];
            cnt++;
            tans+=w;
            xz[i]=1;
            g[u].push_back(v);
            g[v].push_back(u);
        }
        if(cnt==n-1){
            break;
        }
    }
    dfs1(1,0);
    dfs2(1,0,1);
    biuld(1,n,1);
    for(int i=1;i<=m;i++){
        if(xz[i]!=1){
            ans[s[i].op]=tans;
            lca(s[i].u,s[i].v,s[i].w);
        }
    }
    int pt;
    int gx;
    for(int i=1;i<=m;i++){
        if(xz[i]==1){
            u=s[i].u,v=s[i].v,w=s[i].w;
            pt=dep[u]>dep[v]?u:v;
            gx=ask(1,n,1,dfn[pt]);
            if(gx==2e9){
                ans[s[i].op]=-1;
                continue;
            }
            ans[s[i].op]=tans-s[i].w+gx;
        }
    }
    for(int i=1;i<=m;i++){
        cout<<ans[i]<<"\n";
    }
    return 0;
}