题解:CF2174D Secret Message

· · 题解

前言

个人感觉这道题的真实难度完全到不了 3400。感谢 Getaway_Car 大佬对 corner case 的点拨。

初步转化

首先考虑跑一下 Kruskal,若前 n-1 条边未构成树,那么直接输出它们的和。考虑类似于求次小生成树的方法,从最小生成树上替换一条边使新的图不为一棵树。

计算方法

直接枚举每一条边,设两个端点为 uv。枚举一条不在 uv 路径上的一边,并替换成这条边。由要求答案最小可知,替换下的边一定是能换下的边中边权最大的。暴力计算时间复杂度 O(n^2)

数据结构维护

以 1 为根,考虑拆解每条边换下会产生的贡献,分为以下几种情况,设正在拆解节点 u 的父边的贡献:

1.新边两个端点都在以 u 为根的子树内:这种情况可以先计算每个点到根节点经过边的最大值。然后对于每条新边,用两端点 lca 处记录的值更新答案。

2.都不在子树内:由于子树内 dfn 序连续,所以不在子树内的 dfn 序为 [1,dfn_u-1][dfn_u+sz_u,n]。根据这两个连续区间进行离线扫描,由于只需维护前缀或后缀最大值,所以可以使用树状数组维护单点修改,前缀取 \max

这样这道题的主要部分就完成了,时间复杂度 O(n\log n+m\log n)

Corner Case

显然到这里为止这道题都较容易,但是为什么会 Wrong answer on test 2 的第 39 行?细心地回顾过程,我们发现第一步就可能存在问题。为什么最优解一定只替换一条边?

我们跑 Kruskal 时,选取了前 n-1 条边作为树边。显然,删多条边的最优情况是换下第 n-2n-1 条边,换上第 nn+1 条边。我们先从原树上删下这两条边,此时的树被分为了三个连通块。考虑加入新的两条边,分为以下情况:

1.加入的边不构成树:那么用新的边权和更新答案,由于其他删多边的情况一定更劣,所以不用考虑。

2.加入新边后重新构成树:三个连通块之间原有两条边,新也有两条边。那么一定存在一对连通块之间,原有边 x,新也有边 z,设删除的两条边为 xy。考虑一种新的换边方式:删除 y,加入 z,只用删一条边。这样同样不构成树,答案却更优。可以得出所有删多条边的情况一定不是最优解。

由于比最优解更大无害,所以直接拿 sum-e_{n-2}.w-e_{n-1}.w+e_{n}.w+e_{n+1}.w 更新答案即可。

代码

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <set>
#define int long long
using namespace std;
int read()
{
    char c=getchar();
    int f=1,x=0;
    while(c<'0'||c>'9')
    {
        if(c=='-') f=-1;
        c=getchar();
    }
    while(c>='0'&&c<='9')
    {
        x=(x<<1)+(x<<3)+(c^'0');
        c=getchar();
    }
    return x*f;
}
void print(int x)
{
    if(x<0)
    {
        putchar('-');
        x=-x;
    }
    if(x>9) print(x/10);
    putchar(x%10+'0');
}
const int N=2e5+5,inf=1e18;
int T,n,m,sum,cnt,tot,ct,mx[N],f[N],dfn[N],sz[N],fa[20][N],dep[N];
struct edge{int u,v,w;}e[N],ne[N];
vector<pair<int,int>> g[N];
struct dsu
{
    int fa[N];
    void init(){for(int i=1;i<=n;i++) fa[i]=i;}
    int findfa(int x)
    {
        if(x==fa[x]) return x;
        return fa[x]=findfa(fa[x]);
    }
}d;
struct line{int x,y,v;}L[N];
struct BIT
{
    int c[N];
    int lowbit(int x){return x&(-x);}
    void build(){for(int i=1;i<=n;i++) c[i]=0;}
    void update(int x,int k){while(x<=n) c[x]=max(c[x],k),x+=lowbit(x);}
    int query(int x)
    {
        int res=0;
        while(x) res=max(res,c[x]),x-=lowbit(x);
        return res;
    }
}tr;
bool cmp1(edge x,edge y){return x.w<y.w;}
bool cmp2(edge x,edge y){return x.u>y.u;}
bool cmp3(edge x,edge y){return x.u<y.u;}
bool cmpx(line x,line y){return x.x>y.x;}
bool cmpy(line x,line y){return x.x<y.x;}
void dfs(int u)
{
    dfn[u]=++tot;
    sz[u]=1;
    for(auto it=g[u].begin();it!=g[u].end();it++)
    {
        int v=(*it).first,w=(*it).second;
        if(v==fa[0][u]) continue;
        fa[0][v]=u,f[v]=w,dep[v]=dep[u]+1,mx[v]=max(mx[u],w);
        dfs(v);
        sz[u]+=sz[v];
    }
} 
bool flag;
void kru()
{
    sort(e+1,e+1+m,cmp1);
    for(int i=1;i<=n;i++) g[i].clear();
    d.init();
    cnt=sum=0;
    for(int i=1;i<=m;i++)
    {
        int u=e[i].u,v=e[i].v,w=e[i].w;
        int nu=d.findfa(u);
        int nv=d.findfa(v);
        if(nu==nv)
        {
            ne[++cnt]=e[i];
            if(i<n) flag=true;
            continue;
        }
        d.fa[nv]=nu;
        sum+=w;
        g[u].push_back(make_pair(v,w));
        g[v].push_back(make_pair(u,w));
    }
}
int lca(int x,int y)
{
    if(dep[x]>dep[y]) swap(x,y);
    for(int i=17;i>=0;i--)
        if(dep[fa[i][y]]>=dep[x]) y=fa[i][y];
    if(x==y) return x;
    for(int i=17;i>=0;i--)
        if(fa[i][x]!=fa[i][y]) x=fa[i][x],y=fa[i][y];
    return fa[0][x];
}
signed main()
{
    T=read();
    while(T--)
    {
        n=read();
        m=read();
        for(int i=1;i<=m;i++) e[i].u=read(),e[i].v=read(),e[i].w=read();
        flag=false;
        kru();
        if(m-cnt<n-1||flag)
        {
            if(m<n-1) puts("-1");
            else
            {
                int ans=0;
                for(int i=1;i<=n-1;i++) ans+=e[i].w;
                print(ans);
                putchar('\n');
            }
            continue;
        }
        int now=1,ans=inf;
        if(m>=n+1) ans=sum-e[n-2].w-e[n-1].w+e[n].w+e[n+1].w;
        tot=0;
        dep[1]=1;
        dfs(1);
        for(int i=1;i<=17;i++)
            for(int j=1;j<=n;j++) fa[i][j]=fa[i-1][fa[i-1][j]];
        for(int i=1;i<=cnt;i++)
        {
            int x=mx[lca(ne[i].u,ne[i].v)];
            if(x) ans=min(ans,sum-x+ne[i].w);
            ne[i].u=dfn[ne[i].u];
            ne[i].v=dfn[ne[i].v];
            if(ne[i].u<ne[i].v) swap(ne[i].u,ne[i].v);
        }
        ct=0;
        for(int i=2;i<=n;i++)
        {
            int x=dfn[i]-1;
            L[++ct]=(line){x,n-x+1,f[i]};
        }
        sort(L+1,L+1+ct,cmpx);
        sort(ne+1,ne+1+cnt,cmp2);
        tr.build();
        for(int i=1;i<=cnt;i++)
        {
            int u=ne[i].u,v=ne[i].v,w=ne[i].w;
            while(now<=ct&&L[now].x>=u) tr.update(L[now].y,L[now].v),now++;
            int x=tr.query(n-v+1);
            if(!x) continue;
            ans=min(ans,sum-x+w);
        }
        ct=0;
        for(int i=2;i<=n;i++)
        {
            int x=dfn[i]+sz[i];
            L[++ct]=(line){x,x,f[i]};
        }
        sort(L+1,L+1+ct,cmpy);
        sort(ne+1,ne+1+cnt,cmp3);
        tr.build();
        now=1;
        for(int i=1;i<=cnt;i++)
        {
            int u=ne[i].u,v=ne[i].v,w=ne[i].w;
            while(now<=ct&&L[now].x<=u) tr.update(L[now].y,L[now].v),now++;
            int x=tr.query(v);
            if(!x) continue;
            ans=min(ans,sum-x+w);
        }
        ct=0;
        for(int i=2;i<=n;i++)
        {
            int x=dfn[i]-1,y=dfn[i]+sz[i];
            L[++ct]=(line){y,n-x+1,f[i]};
        }
        sort(L+1,L+1+ct,cmpy);
        tr.build();
        now=1;
        for(int i=1;i<=cnt;i++)
        {
            int u=ne[i].u,v=ne[i].v,w=ne[i].w;
            while(now<=ct&&L[now].x<=u) tr.update(L[now].y,L[now].v),now++;
            int x=tr.query(n-v+1);
            if(!x) continue;
            ans=min(ans,sum-x+w);
        }
        if(ans==inf) ans=-1;
        print(ans);
        putchar('\n');
    }
    return 0;
}