题解:P13310 染紫

· · 题解

前言

好困难的计数题。

Solution

考虑组合意义,平方转化成从连通块内选有序点对,变成对每种染色方案计数四元组 (r_1,r_2,b_1,b_2) 的数量,其中 r_1\to r_2 路径上全为红点、b_1\to b_2 路径上全为蓝点。

但这仍然是不好算的,考虑反过来计算每个四元组在多少中染色方案中合法。

但这仍然是不好算的,考虑拆开红点和蓝点的贡献,分别计算后二者相乘即可。

但这仍然是不好算的,考虑将计数转成概率。

以计算红点路径为例。

考虑对红点赋权 1,蓝点赋权 0,未染色点赋权 \frac{1}{2},则这条路径的概率就是路径上的点权之积。

考虑 \mathrm{fcr}_x 表示以 x 为上端点的红色路径点权积和,\mathrm{fr}_x 表示以 x 为 LCA 的红色路径点权积和。

转移是简单的:

\mathrm{fcr}_x = v_x+v_x\sum_{y\in \operatorname{ch}(x)} \mathrm{fcr}_y \mathrm{fr}_x = \mathrm{fcr}_x + \sum_{y\in \operatorname{ch}(x)} (\mathrm{fcr}_x-v_x\mathrm{fcr}_y)\mathrm{fcr}_y

蓝色记为 \mathrm{fcb}_x\mathrm{fb}_x,同理计算即可。

答案是 \sum_x \mathrm{fr}_x \times \sum_y \mathrm{fb}_y

但是注意到这样的答案是错的,因为可能有两条路径有交的情况,这种情况显然不应该被计数。

考虑容斥,我们已经计算出来了所有情况,减去红点 LCA 在蓝色路径上、蓝点 LCA 在红色路径上,加上二者 LCA 重合的情况即可。

第三种利用刚才处理的东西是可以算出来的。

我们观察到,\mathrm{fr}_{\mathrm{root}} 所代表的就是有多少条红色路径经过了根节点。

但是我们需要每个点的这种信息,因此我们做一遍换根 DP 即可,将得到的信息记为 \mathrm{nfr}\mathrm{nfb}

剩下就是好统计的了,\mathrm{nfr}_x \mathrm{fb}_x + \mathrm{fr}_x \mathrm{nfb}_x 就是在 x 处算重的答案。

Code

:::success[代码]

#include<bits/stdc++.h>
#define inf 0x3f3f3f3f
#define infll 0x3f3f3f3f3f3f3f3fll
using namespace std;

static const unsigned long long mod=998244353;
struct modint{
    // ...
};

static const modint _2=((modint)(2)).getinv();

int n;
modint Pr[2000010],Pb[2000010];
modint ans=0;

vector<int> g[2000010];

modint dp_ch_r[2000010],dp_r[2000010];
void dfs1(int x,int p=0){
    for(int y:g[x]) if(y!=p) dfs1(y,x),dp_ch_r[x]+=dp_ch_r[y];
    dp_ch_r[x]=(dp_ch_r[x]*Pr[x]+Pr[x]);
    dp_r[x]=dp_ch_r[x];
    for(int y:g[x]) if(y!=p) dp_r[x]+=(dp_ch_r[x]-Pr[x]*dp_ch_r[y])*dp_ch_r[y];
}

modint dp_ch_b[2000010],dp_b[2000010];
void dfs2(int x,int p=0){
    for(int y:g[x]) if(y!=p) dfs2(y,x),dp_ch_b[x]+=dp_ch_b[y];
    dp_ch_b[x]=(dp_ch_b[x]*Pb[x]+Pb[x]);
    dp_b[x]=dp_ch_b[x];
    for(int y:g[x]) if(y!=p) dp_b[x]+=(dp_ch_b[x]-Pb[x]*dp_ch_b[y])*dp_ch_b[y];
}

modint ndp_r[2000010],ndp_b[2000010];
void dfs3(int x,int p=0){
    ans-=(dp_r[x]*ndp_b[x])+(dp_b[x]*ndp_r[x]);
    for(int y:g[x]) if(y!=p){
        ndp_r[y]=dp_r[y]+(2*dp_ch_r[y]*(dp_ch_r[x]-Pr[x]*dp_ch_r[y]));
        ndp_b[y]=dp_b[y]+(2*dp_ch_b[y]*(dp_ch_b[x]-Pb[x]*dp_ch_b[y]));
        dp_ch_r[y]+=(dp_ch_r[x]-Pr[x]*dp_ch_r[y])*Pr[y];
        dp_ch_b[y]+=(dp_ch_b[x]-Pb[x]*dp_ch_b[y])*Pb[y];
        dfs3(y,x);
    }
}

int main(){
    cin.tie(0)->sync_with_stdio(false);

    cin>>n;
    for(int i=1;i<n;i++){
        int x,y;cin>>x>>y;
        g[x].push_back(y),g[y].push_back(x);
    }

    int cnt=0;
    for(int i=1;i<=n;i++){
        char c;cin>>c;
        if(c=='r') Pr[i]=1,Pb[i]=0;
        if(c=='b') Pr[i]=0,Pb[i]=1;
        if(c=='w') Pr[i]=_2,Pb[i]=_2,cnt++;
    }

    dfs1(1),dfs2(1);
    modint rsum=0,bsum=0;
    for(int i=1;i<=n;i++){
        rsum+=dp_r[i],bsum+=dp_b[i];
        ans+=(dp_r[i]*dp_b[i]);
    }
    ans+=(rsum*bsum);

    ndp_r[1]=dp_r[1],ndp_b[1]=dp_b[1];
    dfs3(1);
    cout<<ans*(modint(2).qpow(cnt))<<"\n";

    # ifdef KarmaticEnding
    cerr<<"\n\033[1;38;2;234;200;225mUsed time: "<<clock()*1.0/CLOCKS_PER_SEC<<"s.\033[0m\n";
    # endif
    return 0;
}

:::