【学习笔记】Prüfer 序列

· · 算法·理论

定义

Prüfer 序列能够将一个有 n 个节点的树用 [1,n] 中的 n-2 个元素表示,每一个 Prüfer 序列对应唯一一颗树,每一颗树对应唯一一个 Prüfer 序列。

如何对于一棵树建立它的 Prüfer 序列?每次选择该树最小的叶子结点并删掉它,然后在序列的末尾把它所连接的节点加入,知道只剩 2 个节点,此时 Prüfer 序列中有 n-2 个元素。

实现

首先考虑暴力实现,显然有一个 O(n^2) 的算法可以实现,具体的,枚举最小的没被删除的叶子结点并删除。

然后考虑优化,不难发现用堆即可优化到 O(n\log n)

但是,Prüfer 序列是可以线性构造的,具体流程如下:

维护一个指针 p,初始时 p 指向编号最小的叶子结点,同时我们维护每个节点的出度 cnt 以及每个节点的父亲,每次删除一个节点 u 时,将它的父亲 fa 加入到序列中,同时 cnt_{fa}=cnt_{fa}-1,然后判断若 cnt_{fa}=0fa<uu=fa 然后继续,否则就 u 递加,可以证明此方法正确。

不难发现编号最大的节点一定不会被删除,因为一个点数多于 2 的树必然不止一个叶子结点。故我们在对一棵无根树建立 Prüfer 序列时可以将最大的节点作为根。

解决的问题

Prüfer 序列可以解决许多与生成树计数有关的题目,每有一个 Prüfer 序列就有一种生成树的方案。

例题

【模板】Prufer 序列

板子,按照上面的线性建 Prüfer 序列的方法即可,但是还需要将 Prüfer 序列转化为树,方法其实与建立 Prüfer 序列差不多,故不在赘述。

由于题目说以 n 为根,所以可以不用建树。

代码如下:

#include<bits/stdc++.h>
using namespace std;
const int N=1e7+5;
int fa[N],cnt[N],p[N];
signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);
    int n,op;
    cin>>n>>op;
    long long ans=0;
    if(op==1)
    {
        for(int i=1;i<n;i++)cin>>fa[i],cnt[fa[i]]++;
        for(int i=1,j=1;i<=n-2;i++,j++)
        {
            while(cnt[j]!=0)j++;
            p[i]=fa[j];
            while(i<n-2&&--cnt[p[i]]==0&&p[i]<j)p[i+1]=fa[p[i]],i++;
        }
        for(int i=1;i<=n-2;i++)ans^=1ll*i*p[i];
    }
    else
    {
        for(int i=1;i<=n-2;i++)cin>>p[i],cnt[p[i]]++;
        p[n-1]=n;
        for(int i=1,j=1;i<n;i++,j++)
        {
            while(cnt[j]!=0)j++;
            fa[j]=p[i];
            while(i<n-1&&--cnt[p[i]]==0&&p[i]<j)fa[p[i]]=p[i+1],i++;
        }
        for(int i=1;i<n;i++)ans^=1ll*i*fa[i];
    }
    cout<<ans;
    return 0;
}

[HNOI2004] 树的计数

根据 Prüfer 序列的性质,每个 Prüfer 序列对应一种生成树的方案,故考虑有几种 Prüfer 序列。

题目给定每个节点的度数,即规定了每个节点在 Prüfer 序列出现的次数,算个多重组合数即为答案。设答案为 ans,则有:

ans=\frac{(n-2)!}{\prod_{i=1}^n(d_i-1)!}

上式中 n-2 为 Prüfer 序列的长度,d_i-1 为每个节点出现的个数。

需要注意的是,会出现无解的情况,所有情况如下:

代码如下:

#include<bits/stdc++.h>
using namespace std;
const int N=205;
int cnt[N];
void solve(int n,int val)
{
    for(int i=1;i<=n;i++)
    {
        int tmp=i;
        for(int j=2;j*j<=tmp;j++)
            while(tmp%j==0)cnt[j]+=val,tmp/=j;
        if(tmp>1)cnt[tmp]+=val;
    }
}
int a[N];
signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);
    int n,sum=0,c=0;
    cin>>n;
    for(int i=1;i<=n;i++)cin>>a[i],sum+=a[i],c+=(a[i]==1?1:0);
    if(sum!=2*(n-1)||n!=1&&c<2||n==1&&a[1]!=0)
    {
        cout<<0;
        return 0;
    }
    solve(n-2,1);
    for(int i=1;i<=n;i++)solve(a[i]-1,-1);
    long long ans=1;
    for(int i=1;i<=150;i++)
    for(int j=1;j<=cnt[i];j++)
        ans*=i;
    cout<<ans;
    return 0;
}