虚树

· · 算法·理论

虚树

问题导入

CF613D Kingdom and its Cities

题面

给定 n 个点的树,m 次询问,每次有 k 个点,删除其他的点使得这 k 给点互不连通,求最少删多少个点?

#### 思路 若直接 $n$ 个点跑,复杂度为 $O(nm)$ 直接爆炸,所以我们可以设想把 $k$ 个点先提取出来,然后跑 $dp

虚树引入

为了解决这个问题,我们可以用虚树

定义

虚树是由查询点两两查询点的lca还有根节点组成

虚树不保证原树的结构,但是继承了原树节点的先后顺序

举个栗子(红色是查询点)

那么根据虚树的定义,这棵树的虚树就是

这样就可以减少树形dp的负担

应用

虚树一般是 lca + 建虚树 + 树形dp

lca就是普通的lca,代码贴一下

void df(int x)
{
    siz[x]=1;
    dfn[x]=++C;
    dep[x]=dep[fa[x]]+1;
    for(node xx:p[x])
    {
        int v=xx.v;
        int w=xx.w;
        if(v==fa[x]) continue;
        fa[v]=x;
        df(v);
        siz[x]+=siz[v];
        if(son[x]==0||siz[son[x]]<siz[v]) son[x]=v;
    }
}
void ds(int x,int tp)
{
    top[x]=tp;
    if(son[x]==0) return;
    ds(son[x],tp);
    for(node xx:p[x])
    {
        int v=xx.v;
        if(v==fa[x]||v==son[x]) continue;
        ds(v,v);
    }
}
int lca(int x,int y)
{
    while(top[x]!=top[y])
    {
        if(dep[top[x]]>=dep[top[y]]) x=fa[top[x]];
        else y=fa[top[y]];
    }
    if(dep[x]<dep[y]) return x;
    return y;
}

然而怎么建虚树恁?

建虚树前,我们先对查询的点集 a 进行 dfn 的排序,保证建虚树时是一条链一条链的入栈

对于要入栈的点 a_i,设lca为 lca(a_i,s[top])

  1. lca = s[top] ,就表示 a_is[top] 是一条链上的,就直接加入栈中

  2. lca \ne s[top],就表示当前的链已经入完了,就要建边,如图

我们在出栈时,当dep[top-1]<dep[lca] 停止出栈

出栈时,就把s[top-1]s[top]连边(注意时有向边)

出完栈后,判断s[top]是否为lca

  1. 若为 lca ,a_i可以直接入栈

  2. 否则 s[top]肯定是lca的子树,连边,然后加入lcaa_i

代码

void build()
{
    sort(a+1,a+1+m,cmp);
    int cnt=1;
    s[cnt]=1;
    s[++cnt]=a[1];
    for(int i=2;i<=m;i++)
    {
        int l=lca(a[i],s[cnt]);
        while(cnt>1&&dep[s[cnt-1]]>=dep[l]) 
        {
            q[s[cnt-1]].push_back(s[cnt]); cnt--;
        }
        if(l!=s[cnt])
        {
            q[l].push_back(s[cnt]); s[cnt]=l;
        }
        s[++cnt]=a[i];
    }
    while(cnt)
    {
        q[s[cnt-1]].push_back(s[cnt]);
        cnt--;
    }
}

建完图后就可以进行树形dp了,本题的代码我就不贴了

P2495 [SDOI2011] 消耗战 /【模板】虚树

题面

给定 n 个点的树,边有点权,m 次询问,每次有 k 个点,删除边使得这 k 给点与根不连通,求最少代价?

思路

不难想到用虚树

但在建树的过程中,我们可以剪枝:

我们可以将一个查询点的子树全部去掉,不难证明

我们可以在dfs的时候预处理从根到当前的最小边权,跑树形dp时也很简单

code

#include<bits/stdc++.h>
using namespace std;
#define int long long
int n,m;
int T;
struct node
{
    int v,w;
};
vector<node> p[250001];
vector<int> q[250001];
int siz[250001];
int dep[250001];
int top[250001];
int son[250001];
int s[250001];
int a[250001];
int fa[250001];
int minn[250001];
int dfn[250001];
int h[250001];
int C;
const int inf=1e18;
void df(int x)
{
    siz[x]=1;
    dfn[x]=++C;
    dep[x]=dep[fa[x]]+1;
    for(node xx:p[x])
    {
        int v=xx.v;
        int w=xx.w;
        if(v==fa[x]) continue;
        fa[v]=x;
        minn[v]=min(minn[x],w);
        df(v);
        siz[x]+=siz[v];
        if(son[x]==0||siz[son[x]]<siz[v]) son[x]=v;
    }
}
void ds(int x,int tp)
{
    top[x]=tp;
    if(son[x]==0) return;
    ds(son[x],tp);
    for(node xx:p[x])
    {
        int v=xx.v;
        if(v==fa[x]||v==son[x]) continue;
        ds(v,v);
    }
}
int lca(int x,int y)
{
    while(top[x]!=top[y])
    {
        if(dep[top[x]]>=dep[top[y]]) x=fa[top[x]];
        else y=fa[top[y]];
    }
    if(dep[x]<dep[y]) return x;
    return y;
}
bool cmp(int x,int y)
{
    return dfn[x]<dfn[y];
}
void build()
{
    sort(a+1,a+1+m,cmp);
    int cnt=1;
    s[cnt]=1;
    s[++cnt]=a[1];
    for(int i=2;i<=m;i++)
    {
        int l=lca(a[i],s[cnt]);
        if(l==s[cnt]) continue;
        while(cnt>1&&dep[s[cnt-1]]>=dep[l]) 
        {
            q[s[cnt-1]].push_back(s[cnt]); cnt--;
        }
        if(l!=s[cnt])
        {
            q[l].push_back(s[cnt]); s[cnt]=l;
        }
        s[++cnt]=a[i];
    }
    while(cnt)
    {
        q[s[cnt-1]].push_back(s[cnt]);
        cnt--;
    }
}
int dp(int x)
{
    if(q[x].size()==0) return minn[x];
    int sum=0;
    for(int xx:q[x])
    {
        sum+=dp(xx);
    }
    vector<int>().swap(q[x]);
    return min(sum,minn[x]);
}
signed main()
{
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    cin>>n;
    for(int i=1;i<n;i++)
    {
        int u,v,w; cin>>u>>v>>w;
        p[u].push_back({v,w});
        p[v].push_back({u,w});
    }
    minn[1]=inf;
    dep[1]=1;
    df(1);
    ds(1,1);
    cin>>T;
    while(T--)
    {
        cin>>m;
        for(int i=1;i<=m;i++) cin>>a[i];
        build();
        cout<<dp(1)<<'\n';
    }
    return 0;
}

P3233 [HNOI2014] 世界树

题面

给定一个 n 个点的树,m 次询问:

每次询问给定 k 个点为管辖点,定义一个点 xy 管辖为:距离 x 最近且编号最小的管辖点 y

求这 k 个管辖点分别管辖多少个点?

1 \le n,m \le 3 \times 10^5,\sum k \le 3*10^5