虚树

· · 算法·理论

引入

[SDOI2011] 消耗战

简要题意:给定 n 个点的树,边权为 w,再给出 m 个询问, k 个关键点,求最小代价使所有给定点都和 1 号节点不连通。

暴力

一眼树形 dp 。设 dp_i 为点 i 不和子树中的任意一点连通的最小代价。

易得转移方程。考虑枚举儿子 v 。如果 i 不为关键点,则 dp_i=dp_i+\min(dp_v,w(i,v))。如果 i 是关键点,则 dp_i=dp_v+w(i,v)

复杂度 O(nm) 显然过不了。

概念

发现有很多的点没有用,只有关键点和他们的 lca 有用。

考虑将关键点和它们两两的 lca 连边,建成一棵树,那么这棵树就是虚树。

构造过程

  1. 排序去重法(警示:需特判一个关键点的情况

在关键点序列上,枚举相邻的两个数,两两求得 lca 并且加入序列 A 中。

因为 DFS 序的性质,此时的序列 A 已经包含了虚树中的所有点,但是可能有重复。

所以我们把序列 A 按照 dfn 序从小到大排序并且去重。

最后,在序列 A 上,枚举相邻的两个数 x,y,求得它们的 lca 并且连接 lca,y,虚树就构造完成了。

代码:

len=m;
sort(h+1,h+m+1,cmp);
f(i,2,len)
    h[++m]=lca(h[i],h[i-1]);
sort(h+1,h+m+1,cmp);
m=unique(h+1,h+m+1)-h-1;
f(i,2,m)
    add(lca(h[i],h[i-1]),h[i]);
  1. 单调栈法

用单调栈维护树上的一条链,其中栈里相邻的两个节点在虚树上是相邻的,而且栈从底部到栈首的 dfs 序是单调递增的。

首先将节点 1 入栈。

接下来按照 DFS 序从小到大添加关键节点。

如果当前节点与栈顶节点的 LCA 就是栈顶节点的话,则说明它们是在一条链上的。所以直接把当前节点入栈就行了。

如果当前节点与栈顶节点的 LCA 不是栈顶节点的话,弹栈,直到次大节点的 Dfs 序小于等于 LCA 的 DFS 序。

具体看代码。

本题解法

建立虚树后,跑树形 dp 即可。

代码

单调栈:

#include<bits/stdc++.h>
#define f(i,l,r) for(int i=l;i<=r;++i)
#define F(i,r,l) for(int i=r;i>=l;--i)
#define int long long
#define ULL unsigned long long
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
#define read(n) {int _x=0,_ty=1;char _c=getchar();while(!isdigit(_c)){if(_c=='-')_ty=-1;_c=getchar();}while(isdigit(_c))_x=10*_x+_c-'0',_c=getchar();n=_x*_ty;}
char buf[1<<21],*p1=buf,*p2=buf;
using namespace std;
const int N=250005; 
int n,m,q,k[N],tot,head[N],nxt[N<<1],to[N<<1],val[N<<1];
int siz[N],top[N],dfn[N],son[N],fa[N],dep[N],mn[N],cnt;
int st[N],tp;
vector<int> v[N];
void add(int u,int v,int w){
    nxt[++tot]=head[u];
    head[u]=tot;
    to[tot]=v;
    val[tot]=w;
}
void dfs1(int u,int fath){
    siz[u]=1;
    fa[u]=fath;
    dep[u]=dep[fath]+1;
    for(int i=head[u];i;i=nxt[i]){
        int v=to[i];
        if(v!=fath){
            mn[v]=min(mn[u],val[i]);
            dfs1(v,u);
            siz[u]+=siz[v];
            if(siz[son[u]]<siz[v])
                son[u]=v; 
        }
    }
}
void dfs2(int u,int tp){
    top[u]=tp;
    dfn[u]=++cnt;
    if(son[u])
        dfs2(son[u],tp);
    for(int i=head[u];i;i=nxt[i]){
        int v=to[i];
        if(v!=fa[u]&&v!=son[u])
            dfs2(v,v);
    }
}
int Lca(int u,int v){
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]])
            swap(u,v);
        u=fa[top[u]];
    }
    if(dep[u]>dep[v])
        swap(u,v);
    return u;
}
bool cmp(int a,int b){
    return dfn[a]<dfn[b];
}
int dfs(int u){
    if(v[u].empty())
        return mn[u];
    int ans=0;
    for(int i=0;i<v[u].size();++i)
        ans+=dfs(v[u][i]);
    v[u].clear();
    return min(mn[u],ans); 
}
signed main(){
    read(n);
    f(i,1,n-1){
        int u,v,w;
        read(u);
        read(v);
        read(w);
        add(u,v,w);
        add(v,u,w);
    }
    mn[1]=LLONG_MAX;
    dfs1(1,0);
    dfs2(1,1);
    read(q);
    while(q--){
        read(m);
        f(i,1,m)
            read(k[i]);
        sort(k+1,k+m+1,cmp);
        st[tp=1]=1;
        f(i,1,m){
            if(tp==1){
                st[++tp]=k[i];
                continue;
            }
            int lca=Lca(k[i],st[tp]);
            if(lca==st[tp])
                continue;
            while(tp&&dfn[st[tp-1]]>=dfn[lca]){
                v[st[tp-1]].push_back(st[tp]);
                --tp;
            }
            if(st[tp]!=lca){
                v[lca].push_back(st[tp]);
                st[tp]=lca;
            }
            st[++tp]=k[i];
        }
        while(tp){
            v[st[tp-1]].push_back(st[tp]);
            --tp;
        }
        printf("%lld\n",dfs(1));
    }
    return 0;
}

双排序:

#include<bits/stdc++.h>
#define f(i,l,r) for(int i=l;i<=r;++i)
#define F(i,r,l) for(int i=r;i>=l;--i)
#define int long long
#define ULL unsigned long long
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
#define read(n) {int _x=0,_ty=1;char _c=getchar();while(!isdigit(_c)){if(_c=='-')_ty=-1;_c=getchar();}while(isdigit(_c))_x=10*_x+_c-'0',_c=getchar();n=_x*_ty;}
char buf[1<<21],*p1=buf,*p2=buf;
using namespace std;
const int N=250005,M=500005; 
int n,m,q,k[N],h[N],tot,head[N],nxt[M],to[M],val[M],siz[N],top[N],dfn[N],son[N],fa[N],dep[N],mn[N],cnt;
vector<int> v[N];
bool vis[N];
void add(int u,int v,int w){
    nxt[++tot]=head[u];
    head[u]=tot;
    to[tot]=v;
    val[tot]=w;
}
void dfs1(int u,int fath){
    siz[u]=1;
    fa[u]=fath;
    dep[u]=dep[fath]+1;
    for(int i=head[u];i;i=nxt[i]){
        int v=to[i];
        if(v!=fath){
            mn[v]=min(mn[u],val[i]);
            dfs1(v,u);
            siz[u]+=siz[v];
            if(siz[son[u]]<siz[v])
                son[u]=v; 
        }
    }
}
void dfs2(int u,int tp){
    top[u]=tp;
    dfn[u]=++cnt;
    if(son[u])
        dfs2(son[u],tp);
    for(int i=head[u];i;i=nxt[i]){
        int v=to[i];
        if(v!=fa[u]&&v!=son[u])
            dfs2(v,v);
    }
}
int LCA(int u,int v){
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]])
            swap(u,v);
        u=fa[top[u]];
    }
    if(dep[u]>dep[v])
        swap(u,v);
    return u;
}
bool cmp(int a,int b){
    return dfn[a]<dfn[b];
}
int dfs(int u){
    if(v[u].empty())
        return mn[u];
    int ans=0;
    for(int i=0;i<v[u].size();++i)
        ans+=dfs(v[u][i]);
    v[u].clear();
    return vis[u]?mn[u]:min(mn[u],ans); 
}
signed main(){
    read(n);
    f(i,1,n-1){
        int u,v,w;
        read(u);read(v);read(w);
        add(u,v,w);
        add(v,u,w);
    }
    mn[1]=LLONG_MAX;
    dfs1(1,0);
    dfs2(1,1);
    read(q);
    while(q--){
        read(m);
        f(i,1,m){
            read(h[i]=k[i]);
            vis[k[i]]=1;
        }
        int len=m,mndep=LLONG_MAX,rt=h[1];
        sort(h+1,h+m+1,cmp);
        f(i,2,len)
            h[++m]=LCA(h[i],h[i-1]);
        sort(h+1,h+m+1,cmp);
        m=unique(h+1,h+m+1)-h-1;
        f(i,2,m){
            int lca=LCA(h[i],h[i-1]);
            if(dep[lca]<mndep)
                mndep=dep[lca],rt=lca;
            v[lca].push_back(h[i]);
        }
        printf("%lld\n",dfs(rt));
        f(i,1,len)
            vis[k[i]]=0;
    }
    return 0;
}