P3806 【模板】点分治1

· · 个人记录

题目

题意:给定一棵带权无根树,问是否有点对的距离为 k

假如我们要遍历树上所有的点对,最朴素的想法是枚举端点,对每个点来一趟 dfs ,复杂度 \mathcal{O}(n^2) 为了效率更高地解决问题,我们引入分治思想。对于每个点,我们分别考虑包含这个点的路径和不包含这个点的路径。对于前者,我们做一趟 dfs;对于后者,我们删除该点后,对所有子树递归地处理即可。

但是如果直接做,复杂度是不稳定的,例如当树是一条链时,有可能会退化成 \mathcal{O}(n^2) 。然而,如果每次都选择子树的重心,那么复杂度就可以保证为 \mathcal{O}(n \log n) 。因为重心的子树大小不超过 \frac{n}{2} 所以每次递归问题规模可以下降一半或以上。这种做法就叫做点分治。

找重心的话,可以用一次 dfs

//sz[x]表示以x为根的子树大小,maxp[x]表示x的最大子树
void dfs(int x,int fa,int total){  //找重心 
    sz[x]=1,maxp[x]=0;
    for(int i=Head[x];i;i=Next[i]){
        int y=to[i];
        if(y==fa||vis[y]) continue;
        dfs(y,x,total);
        sz[x]+=sz[y];
        maxp[x]=max(maxp[x],sz[y]);
    }
    maxp[x]=max(maxp[x],total-sz[x]);//还有一棵以其父亲节点为根的子树
    if(!ctr||maxp[x]<maxp[ctr]) ctr=x;//找到最优的根
}

找到重心后,我们需要不断分治,在分治过程中也需要不断找重心来优化,因为在找到重心后,其实子树size会发生变化,所以最好再以重心为根再dfs一遍 ,其实不再 dfs 的复杂度也是对的,可以看这里的证明:

void solve(int x){   //分治 
    vis[x]=1,calc(x);
    for(int i=Head[x];i;i=Next[i]){
        int y=to[i];
        if(vis[y]) continue;
        ctr=0,dfs(y,-1,sz[y]),dfs(ctr,-1,sz[y]); 
        solve(ctr);
    }
}

由于我们保证了所有的路径都是第一种(经过根节点的路径),所以我们对于每一个根,可以先预处理出每一个子节点到根的距离,这样我们就可以得到对于每一个点可能出现的距离

记当前的重心为 ctr

void gdis(int x,int fa,int dis,int from){
    a[++tot]=x,d[x]=dis,b[x]=from;
    for(int i=Head[x];i;i=Next[i]){
        int y=to[i];
        if(y==fa||vis[y]) continue;
        gdis(y,x,dis+edge[i],from);
    } 
}

然后需要将 a 数组按 d 的大小来排序

bool cmp(int x,int y){
    return d[x]<d[y];
}

最后利用双指针 l , r,把任意两个出现的距离凑在一起,并判断可否凑出我们需要的 k 即可(注意复原的时候不要用 memset ,将 tot 清零即可)

void calc(int x){
    tot=0,a[++tot]=x;  //初始化 
    d[x]=0,b[x]=x;
    for(int i=Head[x];i;i=Next[i]){  //遍历 x 的所有儿子 
        int y=to[i];
        if(vis[y]) continue;
        gdis(y,x,edge[i],y);
    }
    sort(a+1,a+1+tot,cmp);  //排序 
    for(int i=1;i<=m;i++){
        int l=1,r=tot;  
        if(ok[i]) continue;
        while(l<r){
            if(d[a[l]]+d[a[r]]>q[i]) r--;   //当和比询问的长度大时,右指针左移
            else if(d[a[l]]+d[a[r]]<q[i]) l++;  //同上 
            else if(b[a[l]]==b[a[r]]){  //和为询问的长度,但同属一棵子树,继续下一种情况
                if(d[a[r]]==d[a[r-1]]) r--;
                else l++;
            }else{
                ok[i]=1;
                break;
            }
        } 
    }
}

复杂度 \mathcal{O}(n \log^{2}n+n m\log n)

code

#include<bits/stdc++.h>
#define N 10005
using namespace std;
int read(){
    int x=0,f=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-f;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
    return x*f;
} //快读 
int Head[N],Next[N<<1],to[N<<1],edge[N<<1];   //链式前向星 
bool vis[N],ok[N];
int n,m,tot,ctr,maxp[N];
int sz[N],q[N],a[N],b[N],d[N];
void add(int u,int v,int w){
    to[++tot]=v,Next[tot]=Head[u],Head[u]=tot,edge[tot]=w;
}
bool cmp(int x,int y){
    return d[x]<d[y];
}
void dfs(int x,int fa,int total){  //找重心 
    sz[x]=1,maxp[x]=0;
    for(int i=Head[x];i;i=Next[i]){
        int y=to[i];
        if(y==fa||vis[y]) continue;
        dfs(y,x,total);
        sz[x]+=sz[y];
        maxp[x]=max(maxp[x],sz[y]);
    }
    maxp[x]=max(maxp[x],total-sz[x]);
    if(!ctr||maxp[x]<maxp[ctr]) ctr=x;
}
void gdis(int x,int fa,int dis,int from){
    a[++tot]=x,d[x]=dis,b[x]=from;
    for(int i=Head[x];i;i=Next[i]){
        int y=to[i];
        if(y==fa||vis[y]) continue;
        gdis(y,x,dis+edge[i],from);
    } 
}
void calc(int x){
    tot=0,a[++tot]=x; 
    d[x]=0,b[x]=x;
    for(int i=Head[x];i;i=Next[i]){  
        int y=to[i];
        if(vis[y]) continue;
        gdis(y,x,edge[i],y);
    }
    sort(a+1,a+1+tot,cmp);  
    for(int i=1;i<=m;i++){
        int l=1,r=tot;  
        if(ok[i]) continue;
        while(l<r){
            if(d[a[l]]+d[a[r]]>q[i]) r--;   
            else if(d[a[l]]+d[a[r]]<q[i]) l++;  
            else if(b[a[l]]==b[a[r]]){  
                if(d[a[r]]==d[a[r-1]]) r--;
                else l++;
            }else{
                ok[i]=1;
                break;
            }
        } 
    }
}
void solve(int x){    
    vis[x]=1,calc(x);
    for(int i=Head[x];i;i=Next[i]){
        int y=to[i];
        if(vis[y]) continue;
        ctr=0,dfs(y,-1,sz[y]),dfs(ctr,-1,sz[y]);
        solve(ctr);
    }
}
int main(){
    n=read(),m=read();
    for(int i=1;i<n;i++){
        int u=read(),v=read(),w=read();
        add(u,v,w),add(v,u,w);
    }
    dfs(1,-1,n),dfs(ctr,-1,n);   
    for(int i=1;i<=m;i++){
        q[i]=read();
        if(!q[i]) ok[i]=1;   //特判
    } 
    solve(ctr);
    for(int i=1;i<=m;i++){
        if(ok[i]) printf("AYE\n");
        else printf("NAY\n");
    }
    return 0;
} 

参考资料:算法学习笔记(73): 点分治,P3806 【模板】点分治1题解