题解:CF1454E Number of Simple Paths

· · 题解

基环树,那么上套路。

我们先找到环,这里我使用了 dfs 找环。然后考虑对环上每个点的子树求贡献,然后再回到环上求整个贡献。

考虑如果是一棵树,怎么求贡献。显然是个树形 dp,不妨往下搜索,一个点的子树对其贡献有两种:单棵子树的路径,这种好统计,递归的时候加起来就行了:对于每一个儿子,贡献是这个儿子节点的答案加上这个子树的节点个数(因为每个其儿子节点及其子树中每个节点到这个点都会多出一条路径);还有一种是其各个儿子的子树之间的路径,根据乘法原理,一个子树中的节点到其他全部节点都会有等于两者个数乘积的贡献,记 node_i 表示 i 节点的子树的节点个数(包括这个节点本身),则这部分贡献为 \Large \frac{\sum node_u-1-node_v}{2}。把这两部分加起来就是一个点及其子树的贡献。

然后我们回到环上考虑。其实和前面说的跨环的贡献是差不多的,区别在于在环上有两个方向可以走,因此不用除以 2。

代码:

#include<iostream>
#include<vector>
#include<cstring>
using namespace std;
const int N=2e5+10; 
int n;
vector<int> e[N];
bool vis[N],flag,flagg,oncir[N];
int id,cnt;
int cir[N];
void get_cir(int u,int fa){  //找环
    vis[u]=1;
    for(int i=0;i<e[u].size();i++){
        int v=e[u][i];
        if(v==fa){
            continue;
        }
        if(vis[v]){
            id=v;
            flag=1;
            oncir[v]=1;
            cir[++cnt]=v;
            return;
        } 
        get_cir(v,u);
        if(flagg){
            return;
        }
        if(flag){
            cir[++cnt]=v;
            oncir[v]=1;
            if(u==id){
                flagg=1;
                return;
            }
            return;
        }
    }
}
long long node[N];
long long dfs(int u,int fa){  //统计u的贡献
    if(e[u].size()==1){  //叶子 
        node[u]=1;
        return 0;
    }
    long long ans=0,sum=0;  
    for(int i=0;i<e[u].size();i++){
        int v=e[u][i];
        if(v==fa||oncir[v]){  //不能跑到环上
            continue;
        }
        ans+=dfs(v,u)+node[v]; 
        node[u]+=node[v];
    }
    node[u]++;  //加上u本身
    for(int i=0;i<e[u].size();i++){
        int v=e[u][i];
        if(v==fa||oncir[v]){
            continue;
        }
        sum+=node[v]*(node[u]-1-node[v]);
    }
    return ans+sum/2;
}
int main(){
    int T;
    cin>>T;
    while(T--){
        cnt=0; 
        memset(vis,0,sizeof vis); 
        memset(cir,0,sizeof cir);
        memset(oncir,0,sizeof oncir); 
        memset(node,0,sizeof node);
        id=flag=flagg=0;
        cin>>n;
        for(int i=1;i<=n;i++){
            int x,y;
            cin>>x>>y;
            e[x].push_back(y);
            e[y].push_back(x); 
        }  
        get_cir(1,0);
        long long ans=0;
        for(int i=1;i<=cnt;i++){
            ans+=dfs(cir[i],0)+node[cir[i]]*(n-node[cir[i]]);
        } 
        cout<<ans<<"\n";
        for(int i=1;i<=n;i++){
            e[i].clear();
        } 
    } 
    return 0;
}