【PKUWC2018】随机游走
nekko
2018-12-25 20:45:55
### 题目描述
给定一棵 $n$ 个结点的树,你从点 $x$ 出发,每次等概率随机选择一条与所在点相邻的边走过去。
有 $Q$ 次询问,每次询问给定一个集合 $S$,求如果从 $x$ 出发一直随机游走,直到点集 $S$ 中所有点都至少经过一次的话,期望游走几步。
特别地,点 $x$(即起点)视为一开始就被经过了一次。
答案对 $998244353 $ 取模。
$1\leq n\leq 18$
$1\leq Q\leq 5000$
$1\leq k\leq n$
### 题解
又是一个 `min-max 容斥` 的板子……
显然有:
$E(\max(S)) = \sum_{T \subseteq S} (-1)^{\mid T \mid + 1}E(\min(T))$
考虑 $E(\min(T))$ 是什么意思呢……就是从 $x$ 出发,经过 $T$ 中的点各至少一次的期望步数
于是可以枚举 $S$,设 $f_u$ 表示从 $u$ 出发,经过 $S$ 中的点各至少一次的期望步数
于是有:
$$f_u=1+\frac{1}{deg_u}\sum_{u \to v} f_v$$
~~然后就可以高斯消元了,由于每次的枚举集合中的有效点总共不多,卡卡常就可以过了~~
注意一下这道题的特殊性,是一棵 **树**,于是可以把 $f_u$ 表示成 $a_u f_{fa_u}+b_u$ 的形式
之后直接随便搞搞就搞出来了……
设 $g_{T}$ 表示钦定的集合为 $T$ 时,从 $x$ 出发,经过 $T$ 中的所有点至少一次的期望步数
对于查询的一个询问 $S$ 来说,答案就是:
$$\sum_{T \subseteq S} (-1)^{\mid T \mid + 1} g_{T}$$
于是可以先 $FMT$ 一下,处理出后面那个的子集和,查询就可以 $O(1)$ 了
总的时间复杂度为 $O(n 2^n \log P+q)$
带上一个 $\log P$ 是因为在预处理的时候要用到一个求逆元
### 代码
``` cpp
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 20, mod = 998244353;
vector<int> g[N];
int n, q, x; ll ans[1 << N], deginv[N];
ll pw(ll a, ll b) {
ll r = 1;
for( ; b ; b >>= 1, a = a * a % mod)
if(b & 1)
r = r * a % mod;
return r;
}
struct T {
ll a, b;
// f[u] = a*f[fa]+b
T(ll a = 0, ll b = 0): a(a), b(b) {}
T operator + (T t) {
return (T) { (a + t.a) % mod, (b + t.b) % mod };
}
} f[1 << N][N];
void dfs(int u, int fa, T *f, int S) {
if(S & (1 << (u - 1))) return ;
T sum;
for(int v: g[u]) {
if(v == fa) continue;
dfs(v, u, f, S);
sum = sum + f[v];
}
ll tmp = pw(1 - deginv[u] * sum.a % mod, mod - 2);
f[u].a = deginv[u] * tmp % mod;
f[u].b = (1 + sum.b * deginv[u] % mod) * tmp % mod;
}
int cnt[1 << 20];
int main() {
ios :: sync_with_stdio(0);
cin >> n >> q >> x;
for(int i = 1, u, v ; i < n ; ++ i)
cin >> u >> v,
g[u].push_back(v),
g[v].push_back(u);
for(int i = 1 ; i <= n ; ++ i)
deginv[i] = pw(g[i].size(), mod - 2);
for(int s = 0 ; s < (1 << n) ; ++ s)
cnt[s] = cnt[s >> 1] + (s & 1);
for(int s = 1 ; s < (1 << n) ; ++ s)
dfs(x, 0, f[s], s),
ans[s] = (cnt[s] & 1 ? 1 : -1) * f[s][x].b % mod;
for(int i = 1 ; i <= n ; ++ i)
for(int s = 0 ; s < (1 << n) ; ++ s)
if(s & (1 << (i - 1)))
(ans[s] += ans[s - (1 << (i - 1))]) %= mod;
for(int i = 1 ; i <= q ; ++ i) {
int k, x = 0, y; cin >> k;
for(int j = 1 ; j <= k ; ++ j)
cin >> y,
x |= 1 << (y - 1);
cout << (ans[x] % mod + mod) % mod << endl;
}
}
```