Nephren 题解

· · 个人记录

我们先考虑已经切断若干边后的情况,这时树会分裂成若干子树。

每个子树也能类似地定义其战斗力为子树内防守区块战斗力之和,而基地的总战斗力就是每个子树的战斗力之和。

我们可以发现,每个子树 A=(V_A,E_A) 的战斗力为 |V_A|C(A),其中 C(A) 是子树 A 的防守区块个数。

由期望的线性性,考虑对每个子树单独计算战斗力期望。

\mathbb E[F(T)] = \sum\limits_A\mathbb E[F(A)]=\sum\limits_A|V_A|\mathbb E[C(A)]

于是我们需要考虑如何计算子树 A 的防守区块个数的期望。

这里有一个很巧妙的结论:

首先在一个有 n 个节点的森林 T 上,C(T)=|V_T|-|E_T|,这个是显然的。

于是我们就能根据期望的性质得到:

\mathbb E[C(T)]&=\mathbb E(|V_T|)-\mathbb E(|E_T|)\\ &=\sum\limits_{v\in V_T}x(v)-\sum\limits_{e\in E_T}x(e)\\ (x&\text{指点被选中或边两端的点都被选中的期望})\\ &=\sum\limits_{v\in V_T}\frac{1}{2}-\sum\limits_{e\in E_T}\frac{1}{4}\\ &=\frac{n}{2}-\frac{n-1}{4}\\ &=\frac{n+1}{4} \end{aligned}

于是我们发现 \mathbb E[C(T)]T 的形态无关,而且是关于 n 的多项式。

所以我们有

\mathbb E[F(T)]=\sum\limits_i\frac{a_i(a_i+1)}{4}

其中 a_i 是第 i 个连通块的大小。

现在我们考虑怎么求对于所有分割方案,上述式子的期望。

我们可以退而求其次,考虑求解上述式子对于所有分割方案的和,然后除以总方案数得到期望。

于是我们可以得到一个显然的 dp:设 f_{u,i} 是以 u 为根节点的大小为 i 的连通块的答案。我们可以在子树上逐个转移,这个复杂度是 O(n^2) 的,可以得到 49 分。

如果要做到线性,我们需要用一个树上连通块 dp 的组合意义优化:

考虑将上式改写为:

\mathbb E[F(T)]=\sum_i\frac{1}{2}\mathbf C_{a_i}^2+\frac{1}{2}\mathbf C_{a_i}^1

于是我们可以通过组合意义写出定义 dp 方程:

对于第一部分 \mathbf C_{a_i}^2,我们定义 f_{u,0/1/2} 为在以 u 为根的连通块中选择 0/1/2 个点的方案数。对于第二部分也可以有类似的定义 g_{u,0/1}

每个连通块 A 对答案的贡献次数即为连通块外的切分方案数,即 2^{n-|A|-1} 次。显然以 u 为根的所有连通块在子树 u 外的切分方案都是相同的,因此我们考虑子树内和子树外分别统计贡献:

子树内:

对于 u 的一个儿子 v,如果 vu 在同一连通块内,那么直接组合计算方案数;

如果不在同一连通块内,那么 u 就会对答案贡献 2^{size_v-1} 次,直接乘在 dp 状态上。

子树外:

可以发现子树外的贡献对于转移过程没有影响,因此统计答案时乘上即可。

最终的答案为 \dfrac 1 {2^{n-1}}\cdot\dfrac1 2\sum\limits_u 2^{n-size_u-1}(f_{u,2}+g_{u,1})

拓展

通过这种方式,我们可以在 O(nk^2) 的时间内计算出所有代价函数是关于连通块大小的 k 次多项式的树上分割计数问题。

代码

#include <bits/stdc++.h>
using namespace std;
#define N 1000005
int n,hd[N],t,d[N],p[N];
typedef long long ll;
ll f[N][3],g[N][2],sz[N],ans,pow2[N];
#define MOD 998244353

struct edge{
    int u,v,nxt;
}es[2*N];

void add_edge(int u,int v){
    es[++t]=(edge){u,v,hd[u]};
    hd[u]=t;
}

ll fastpow(ll x,int p){
    ll res=1;
    while(p>0){
        if(p&1)res = res*x%MOD;
        p>>=1;
        x=x*x%MOD;
    }
    return res;
}

void dfs(int u,int fa){
    f[u][0]=1;f[u][1]=1;
    g[u][0]=1;g[u][1]=1;
    sz[u]=1;
    int v;
    for(int i=hd[u];i;i=es[i].nxt){
        v = es[i].v;
        if(v==fa)continue;
        dfs(v,u);
        sz[u]+=sz[v];
        //C(k,2)
        f[u][2]=(f[u][2]*f[v][0]+f[u][1]*f[v][1]+f[u][0]*f[v][2]//同一连通块
            +f[u][2]*pow2[sz[v]-1])%MOD;//不同连通块

        f[u][1]=(f[u][1]*f[v][0]+f[u][0]*f[v][1]//同一连通块
            +f[u][1]*pow2[sz[v]-1])%MOD;//不同联通块

        f[u][0]=(f[u][0]*f[v][0]//同一连通块
            +f[u][0]*pow2[sz[v]-1])%MOD;//不同连通块 

        //C(k,1)
        g[u][1]=(g[u][1]*g[v][0]+g[u][0]*g[v][1]//同一连通块
            +g[u][1]*pow2[sz[v]-1])%MOD;//不同连通块

        g[u][0]=(g[u][0]*g[v][0]//同一连通块
            +g[u][0]*pow2[sz[v]-1])%MOD;//不同连通块
    }
    ans = (ans+f[u][2]*pow2[(n-sz[u]-1)>=0?n-sz[u]-1:0]+g[u][1]*pow2[(n-sz[u]-1)>=0?n-sz[u]-1:0])%MOD;
}

ll rev(ll x){
    return fastpow(x,MOD-2);
}

int main(){
    cin >> n;
    int u,v;
    pow2[0]=1;
    for(int i=1;i<=n;i++)pow2[i]=pow2[i-1]*2%MOD;
    for(int i=1;i<n;i++){
        scanf("%d%d",&u,&v);
        add_edge(u,v);
        add_edge(v,u);
    }

    dfs(1,0);
    cout << ans << endl;
    printf("%lld",(ans*rev(pow2[n])%MOD));
    return 0;
}

后记

在出数据的时候爆了本地栈qwq