题解:CF1823F Random Walk

· · 题解

集训出的神秘东西。

f_u 表示走到 u 的期望次数,d_u 为结点 u 的度。

显然每次走到 u,对于任意一个相邻的点,都有 \frac{1}{d_u} 的概率走到。

于是可以列出以下柿子。

其中 E 是边集。应该是可以理解的。理解不了的可以看下面的折叠框。

:::info[解释] 第一条显然,因为到 t 之后不会继续走,只有走到的那一次贡献。

第三条也容易推出来,u 旁边的结点 v 每次被走到,都有 \frac{1}{d_v} 的概率走到 u

第二条类似,只是 s 多了一开始的一次。 :::

似乎高斯消元可以做了?但是复杂度是 O(n^3) 的。

这里用到随机游走题目的一个通用做法:将 f_u 表示为 A_u f_{fa_u}+B_u,其中 fa_uu 的父亲结点。

刚才推出,对于一般的 u\displaystyle f_u=\sum_{(u,v)\in E,v\ne t}\frac{1}{d_v}f_v

也就是说,\displaystyle f_u=\frac{[fa_u\ne t]}{d_{fa_u}}f_{fa_u}+\sum_{(u,v)\in E,v\ne t,v\ne fa}\frac{1}{d_v}f_v

注意到这里的 v 都是 u 的儿子,也就是 fa_v=u

也就是说,f_v=A_v f_u+B_u

代入上面的柿子,\displaystyle f_u=\frac{[fa_u\ne t]}{d_{fa_u}}f_{fa_u}+\sum\frac{1}{d_v}(A_v f_u+B_v)(省略 \sum 下标)。

于是 \displaystyle f_u-\sum \frac{A_v}{d_v}f_u=\frac{[fa_u\ne t]}{d_{fa_u}}f_{fa_u}+\sum\frac{B_v}{d_v}

再把左边的系数除过去就变成了 f_u=A_u f_{fa_u}+B_u 的形式。由于柿子过于抽象我就不写了。

这样下来,我们把所有的 f_u 都表示成了上面的形式。并且注意到根结点是没有父亲的,也就是说我们可以直接知道它的值。一层层往下推就能得到所有结点的 f 值了。

:::success[code]

#include<bits/stdc++.h>
using namespace std;
#define int long long
#define fi first
#define se second
#define inv(x) qpow(x,mod-2)
const int N=2e5+10,mod=998244353;
vector<pair<int,int>>adj[N];
int n,s,t,f[N],a[N],b[N],ta[N],tb[N];
int qpow(int a,int b)
{
    int s=1;
    while(b)
    {
        if(b&1) s=s*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return s;
}
void dfs1(int u,int fa)
{
    for(auto& v:adj[u])
    {
        if(v.fi==fa) continue;
        dfs1(v.fi,u);
        if(v.fi!=t) v.se=inv(adj[v.fi].size());
    }
    if(fa&&fa!=t) ta[u]=inv(adj[fa].size());
    if(u==s) tb[u]=1;
    else if(u==t)
    {
        ta[u]=0,tb[u]=1;
        for(auto& v:adj[u])
        {
            if(v.fi==fa) continue;
            v.se=0;
        }
    }
}
void dfs2(int u,int fa)
{
    a[u]=0,b[u]=tb[u];
    for(auto v:adj[u])
    {
        if(v.fi==fa) continue;
        dfs2(v.fi,u);
        a[u]=(a[u]+a[v.fi]*v.se%mod)%mod;
        b[u]=(b[u]+b[v.fi]*v.se%mod)%mod; 
    }
    int x=inv((1-a[u]+mod)%mod);
    a[u]=ta[u]*x%mod;
    b[u]=b[u]*x%mod;
}
void dfs3(int u,int fa)
{
    f[u]=(a[u]*f[fa]%mod+b[u])%mod;
    for(auto v:adj[u])
    {
        if(v.fi==fa) continue;
        dfs3(v.fi,u);
    }
}
signed main()
{
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0);
    cin>>n>>s>>t;
    for(int i=1;i<n;i++)
    {
        int u,v;
        cin>>u>>v;
        adj[u].push_back({v,0});
        adj[v].push_back({u,0});
    }
    dfs1(1,0);
    dfs2(1,0);
    dfs3(1,0);
    for(int i=1;i<=n;i++) cout<<f[i]<<' ';
    return 0;
}

:::