题解:P13310 染紫
前言
好困难的计数题。
Solution
考虑组合意义,平方转化成从连通块内选有序点对,变成对每种染色方案计数四元组
但这仍然是不好算的,考虑反过来计算每个四元组在多少中染色方案中合法。
但这仍然是不好算的,考虑拆开红点和蓝点的贡献,分别计算后二者相乘即可。
但这仍然是不好算的,考虑将计数转成概率。
以计算红点路径为例。
考虑对红点赋权
考虑
转移是简单的:
蓝色记为
答案是
但是注意到这样的答案是错的,因为可能有两条路径有交的情况,这种情况显然不应该被计数。
考虑容斥,我们已经计算出来了所有情况,减去红点 LCA 在蓝色路径上、蓝点 LCA 在红色路径上,加上二者 LCA 重合的情况即可。
第三种利用刚才处理的东西是可以算出来的。
我们观察到,
但是我们需要每个点的这种信息,因此我们做一遍换根 DP 即可,将得到的信息记为
剩下就是好统计的了,
Code
:::success[代码]
#include<bits/stdc++.h>
#define inf 0x3f3f3f3f
#define infll 0x3f3f3f3f3f3f3f3fll
using namespace std;
static const unsigned long long mod=998244353;
struct modint{
// ...
};
static const modint _2=((modint)(2)).getinv();
int n;
modint Pr[2000010],Pb[2000010];
modint ans=0;
vector<int> g[2000010];
modint dp_ch_r[2000010],dp_r[2000010];
void dfs1(int x,int p=0){
for(int y:g[x]) if(y!=p) dfs1(y,x),dp_ch_r[x]+=dp_ch_r[y];
dp_ch_r[x]=(dp_ch_r[x]*Pr[x]+Pr[x]);
dp_r[x]=dp_ch_r[x];
for(int y:g[x]) if(y!=p) dp_r[x]+=(dp_ch_r[x]-Pr[x]*dp_ch_r[y])*dp_ch_r[y];
}
modint dp_ch_b[2000010],dp_b[2000010];
void dfs2(int x,int p=0){
for(int y:g[x]) if(y!=p) dfs2(y,x),dp_ch_b[x]+=dp_ch_b[y];
dp_ch_b[x]=(dp_ch_b[x]*Pb[x]+Pb[x]);
dp_b[x]=dp_ch_b[x];
for(int y:g[x]) if(y!=p) dp_b[x]+=(dp_ch_b[x]-Pb[x]*dp_ch_b[y])*dp_ch_b[y];
}
modint ndp_r[2000010],ndp_b[2000010];
void dfs3(int x,int p=0){
ans-=(dp_r[x]*ndp_b[x])+(dp_b[x]*ndp_r[x]);
for(int y:g[x]) if(y!=p){
ndp_r[y]=dp_r[y]+(2*dp_ch_r[y]*(dp_ch_r[x]-Pr[x]*dp_ch_r[y]));
ndp_b[y]=dp_b[y]+(2*dp_ch_b[y]*(dp_ch_b[x]-Pb[x]*dp_ch_b[y]));
dp_ch_r[y]+=(dp_ch_r[x]-Pr[x]*dp_ch_r[y])*Pr[y];
dp_ch_b[y]+=(dp_ch_b[x]-Pb[x]*dp_ch_b[y])*Pb[y];
dfs3(y,x);
}
}
int main(){
cin.tie(0)->sync_with_stdio(false);
cin>>n;
for(int i=1;i<n;i++){
int x,y;cin>>x>>y;
g[x].push_back(y),g[y].push_back(x);
}
int cnt=0;
for(int i=1;i<=n;i++){
char c;cin>>c;
if(c=='r') Pr[i]=1,Pb[i]=0;
if(c=='b') Pr[i]=0,Pb[i]=1;
if(c=='w') Pr[i]=_2,Pb[i]=_2,cnt++;
}
dfs1(1),dfs2(1);
modint rsum=0,bsum=0;
for(int i=1;i<=n;i++){
rsum+=dp_r[i],bsum+=dp_b[i];
ans+=(dp_r[i]*dp_b[i]);
}
ans+=(rsum*bsum);
ndp_r[1]=dp_r[1],ndp_b[1]=dp_b[1];
dfs3(1);
cout<<ans*(modint(2).qpow(cnt))<<"\n";
# ifdef KarmaticEnding
cerr<<"\n\033[1;38;2;234;200;225mUsed time: "<<clock()*1.0/CLOCKS_PER_SEC<<"s.\033[0m\n";
# endif
return 0;
}
:::