【PKUWC2018】随机游走

nekko

2018-12-25 20:45:55

Personal

### 题目描述 给定一棵 $n$ 个结点的树,你从点 $x$ 出发,每次等概率随机选择一条与所在点相邻的边走过去。 有 $Q$ 次询问,每次询问给定一个集合 $S$,求如果从 $x$ 出发一直随机游走,直到点集 $S$ 中所有点都至少经过一次的话,期望游走几步。 特别地,点 $x$(即起点)视为一开始就被经过了一次。 答案对 $998244353 $ 取模。 $1\leq n\leq 18$ $1\leq Q\leq 5000$ $1\leq k\leq n$ ### 题解 又是一个 `min-max 容斥` 的板子…… 显然有: $E(\max(S)) = \sum_{T \subseteq S} (-1)^{\mid T \mid + 1}E(\min(T))$ 考虑 $E(\min(T))$ 是什么意思呢……就是从 $x$ 出发,经过 $T$ 中的点各至少一次的期望步数 于是可以枚举 $S$,设 $f_u$ 表示从 $u$ 出发,经过 $S$ 中的点各至少一次的期望步数 于是有: $$f_u=1+\frac{1}{deg_u}\sum_{u \to v} f_v$$ ~~然后就可以高斯消元了,由于每次的枚举集合中的有效点总共不多,卡卡常就可以过了~~ 注意一下这道题的特殊性,是一棵 **树**,于是可以把 $f_u$ 表示成 $a_u f_{fa_u}+b_u$ 的形式 之后直接随便搞搞就搞出来了…… 设 $g_{T}$ 表示钦定的集合为 $T$ 时,从 $x$ 出发,经过 $T$ 中的所有点至少一次的期望步数 对于查询的一个询问 $S$ 来说,答案就是: $$\sum_{T \subseteq S} (-1)^{\mid T \mid + 1} g_{T}$$ 于是可以先 $FMT$ 一下,处理出后面那个的子集和,查询就可以 $O(1)$ 了 总的时间复杂度为 $O(n 2^n \log P+q)$ 带上一个 $\log P$ 是因为在预处理的时候要用到一个求逆元 ### 代码 ``` cpp #include <bits/stdc++.h> using namespace std; typedef long long ll; const int N = 20, mod = 998244353; vector<int> g[N]; int n, q, x; ll ans[1 << N], deginv[N]; ll pw(ll a, ll b) { ll r = 1; for( ; b ; b >>= 1, a = a * a % mod) if(b & 1) r = r * a % mod; return r; } struct T { ll a, b; // f[u] = a*f[fa]+b T(ll a = 0, ll b = 0): a(a), b(b) {} T operator + (T t) { return (T) { (a + t.a) % mod, (b + t.b) % mod }; } } f[1 << N][N]; void dfs(int u, int fa, T *f, int S) { if(S & (1 << (u - 1))) return ; T sum; for(int v: g[u]) { if(v == fa) continue; dfs(v, u, f, S); sum = sum + f[v]; } ll tmp = pw(1 - deginv[u] * sum.a % mod, mod - 2); f[u].a = deginv[u] * tmp % mod; f[u].b = (1 + sum.b * deginv[u] % mod) * tmp % mod; } int cnt[1 << 20]; int main() { ios :: sync_with_stdio(0); cin >> n >> q >> x; for(int i = 1, u, v ; i < n ; ++ i) cin >> u >> v, g[u].push_back(v), g[v].push_back(u); for(int i = 1 ; i <= n ; ++ i) deginv[i] = pw(g[i].size(), mod - 2); for(int s = 0 ; s < (1 << n) ; ++ s) cnt[s] = cnt[s >> 1] + (s & 1); for(int s = 1 ; s < (1 << n) ; ++ s) dfs(x, 0, f[s], s), ans[s] = (cnt[s] & 1 ? 1 : -1) * f[s][x].b % mod; for(int i = 1 ; i <= n ; ++ i) for(int s = 0 ; s < (1 << n) ; ++ s) if(s & (1 << (i - 1))) (ans[s] += ans[s - (1 << (i - 1))]) %= mod; for(int i = 1 ; i <= q ; ++ i) { int k, x = 0, y; cin >> k; for(int j = 1 ; j <= k ; ++ j) cin >> y, x |= 1 << (y - 1); cout << (ans[x] % mod + mod) % mod << endl; } } ```