题解:P5575 [CmdOI2019] 黑白图
看不懂官解在干什么(摊手),但我们可以想出一个
首先,显然有一个暴力 dp 做法,即把当前点
但我们发现本题的
我们记一个连通块为
那么如何统计呢?如果直接统计序列会比较麻烦,我们常见的 dp 都是统计集合或者可重集,所以我们先把序列转化为这两者。
是的,这两个都可以做,而且建议读者两个都思考一下,我觉得这非常有价值。
如果我们选择转化为集合,那么我们要讨论的是多少个序列会被转换为一个大小为
这个问题用容斥可以解决,或者你一眼看出这个就是
那我们可以把式子写一下(记
其实
当然,我保留
如果写出 dp 方程,我们有
- 这个子树的根选了,那先有
f_{u, 0 / 1} \leftarrow p_{u} - 在选了的基础上考虑从儿子转移,有
f_{u, i + j} \leftarrow f'_{u, i} \times f_{v, j} 。这个转移的正确性是显然的:考虑其意义,如果子树根没选,那乘白点的概率;如果子树根选了,那就是合并每两个连通块,概率之积肯定是对的,又因为我们枚举所有i, j ,根据范德蒙德卷积,所以那个组合数也是对的。 - 这个子树的根没选,那有
f_{u, 0} \leftarrow 1 - p_{u} (p_{u} 为这个点是黑点的概率)。这个是要最后加的。
统计答案的时候就对每个点都按照上面期望的式子加起来,记得乘不选父节点的概率(如果有)。
当然,作为久经考验的 OIer,你一眼就看出了这就是生成函数。
我们不妨把它写成这样:
这看着就漂亮多了。当然如果只做树的情况这是没什么用的,在基环树的情况下这样写的优点才会体现出来。
如果我们选择转换为可重集,那么我们要讨论的是多少个序列会被转换为一个可重集,这个问题就显然多了,是多重组合数,或者说我们一般用
因为和集合的情况类似,我们直接给出结论(作为久经考验的 OIer,你其实可以直接猜出下面的式子):
统计答案的时候只需要统计第
总结下,我们树的情况直接 dfs 维护一个多项式即可。暴力多项式乘法,时间复杂度
可是这一题还有基环树的情况,这就麻烦了。
我们考虑基环树上的连通块的形态,大致可以分为三种:
- 把所有环上的点都包括了。这个不需要额外乘不选的概率。
- 除了一个点,包括了所有环上的点。这个要乘那个点不选的概率。
- 其他。这个要乘两个端点不选的概率。
我们设环上的点为
如果我们直接这样一个个乘,那么时间复杂度会来到
所以我们对于这三类连通块,有:
- 直接把所有
G_{w_{i}} 乘起来,然后贡献答案。 - 在上面的基础上把自己的
G_{w_{i}} 除去,然后乘1 - p_{w_{i}} ,贡献答案。 - 把链断开,当作树处理。这样我们的连通块有一个端点是一致的,对于不同的另一个端点,其生成函数都形如若干个
G_{w_{i}} 之积再乘另一个端点的1 - p_{w_{j}} 。所以我们维护这个生成函数,同时维护这些G_{w_{i}} 之积,在转移到下一个点的时候减去G_{w_{i}} 之积乘1 - p_{w_{j}} ,然后再乘G_{w_{i'}} ,最后加上1 - p_{i'} 。统计答案与上面类似。
这样我们就搞定了。时间复杂度
最后一个 trick,就是用拓扑排序处理基环树会好写一点。
集合版本代码:
#include <bits/stdc++.h>
#define em emplace_back
using namespace std;
const int N = 2e5 + 5, P = 998244353, K = 5;
int n, m, k, deg[N];
bool vis[N];
struct pint {
int v; pint(int x = 0): v(x) { }
pint operator + (const pint& x) const { return pint(v + x.v >= P ? v + x.v - P : v + x.v); }
pint operator - (const pint& x) const { return pint(v - x.v < 0 ? v - x.v + P : v - x.v); }
pint operator * (const pint& x) const { return pint(1ll * v * x.v % P); }
pint& operator += (const pint& x) { v = v + x.v >= P ? v + x.v - P : v + x.v; return *this; }
pint& operator -= (const pint& x) { v = v - x.v < 0 ? v - x.v + P : v - x.v; return *this; }
pint& operator *= (const pint& x) { v = 1ll * v * x.v % P; return *this; }
};
const pint A[K + 1][K + 1] = {{}, {0, 1}, {0, 1, 2}, {0, 1, 6, 6}, {0, 1, 14, 36, 24}, {0, 1, 30, 150, 240, 120}};
const pint i100 = 828542813;
using Poly = array<pint, K + 1>;
pint Pow(pint base, int power) {
pint res(1); for (; power; base *= base, power >>= 1) if (power & 1) res *= base;
return res;
}
Poly mul(const Poly& x, const Poly& y) {
Poly z;
for (int i = 0; i <= k; ++i)
for (int j = 0; i + j <= k; ++j) z[i + j] += x[i] * y[j];
return z;
}
Poly div(const Poly& x, const Poly& y) {
Poly z; pint iv = Pow(y[0], P - 2);
for (int i = 0; i <= k; ++i) {
z[i] = x[i]; for (int j = 0; j < i; ++j) z[i] -= z[j] * y[i - j];
z[i] *= iv;
}
return z;
}
pint p[N], q[N], ans;
Poly f[N];
vector<int> e[N], w;
void count(const Poly& x, pint mu = 1) {
for (int i = 1; i <= k; ++i) ans += A[k][i] * x[i] * mu;
}
void dfs0(int u, int fa) {
f[u][0] = f[u][1] = p[u];
for (const auto& v: e[u]) {
if (v == fa) continue;
dfs0(v, u); f[u] = mul(f[u], f[v]);
}
count(f[u], q[fa]);
f[u][0] += q[u];
}
void solvetree() { q[0] = 1; dfs0(1, 0); printf("%d\n", ans.v); }
int getnext(int u) {
for (const auto& v: e[u]) if (!vis[v]) return v;
return -1;
}
void solvecircle() {
queue<int> Q;
for (int i = 1; i <= n; ++i) f[i][0] = f[i][1] = p[i];
for (int i = 1; i <= n; ++i) if (deg[i] == 1) Q.push(i);
while (!Q.empty()) {
int u = Q.front(), fa = getnext(u); Q.pop(); vis[u] = true;
count(f[u], q[fa]);
f[u][0] += q[u]; f[fa] = mul(f[fa], f[u]);
--deg[fa]; if (deg[fa] <= 1) Q.push(fa);
}
for (int i = 1; i <= n; ++i) if (!vis[i]) { w.em(i); break; }
int tmp; vis[w[0]] = true; while ((tmp = getnext(w.back())) != -1) { w.em(tmp); vis[tmp] = true; }
Poly g; g[0] = 1; for (const auto& u: w) g = mul(g, f[u]);
count(g);
for (const auto& u: w) count(div(g, f[u]), q[u]);
Poly h = div(div(g, f[w[0]]), f[w[1]]);
g[0] = q[w[1]]; for (int i = 1; i <= k; ++i) g[i] = 0;
for (int i = 2; i < (int)w.size(); ++i) { g = mul(g, f[w[i]]); g[0] += q[w[i]]; }
for (int i = 0; i < (int)w.size(); ++i) {
int u = w[i], u1 = w[(i + 1) % w.size()], u2 = w[(i + 2) % w.size()];
count(g, q[u]);
for (int j = 0; j <= k; ++j) g[j] -= h[j] * q[u1];
h = mul(div(h, f[u2]), f[u]);
g = mul(g, f[u]); g[0] += q[u];
}
printf("%d\n", ans.v);
}
int main() {
scanf("%d%d%d", &n, &m, &k);
for (int i = 1; i <= n; ++i) { scanf("%d", &p[i].v); p[i] *= i100; q[i] = pint(1) - p[i]; }
for (int i = 1; i <= m; ++i) { int u, v; scanf("%d%d", &u, &v); e[u].em(v); e[v].em(u); ++deg[u]; ++deg[v]; }
if (m == n - 1) solvetree();
else solvecircle();
return 0;
}
可重集版本代码:
#include <bits/stdc++.h>
#define em emplace_back
using namespace std;
const int N = 2e5 + 5, P = 998244353, K = 5;
int n, m, k, deg[N];
bool vis[N];
struct pint {
int v; pint(int x = 0): v(x) { }
pint operator + (const pint& x) const { return pint(v + x.v >= P ? v + x.v - P : v + x.v); }
pint operator - (const pint& x) const { return pint(v - x.v < 0 ? v - x.v + P : v - x.v); }
pint operator * (const pint& x) const { return pint(1ll * v * x.v % P); }
pint& operator += (const pint& x) { v = v + x.v >= P ? v + x.v - P : v + x.v; return *this; }
pint& operator -= (const pint& x) { v = v - x.v < 0 ? v - x.v + P : v - x.v; return *this; }
pint& operator *= (const pint& x) { v = 1ll * v * x.v % P; return *this; }
} C[10][10];
const pint fac[6] = {1, 1, 2, 6, 24, 120};
const pint caf[6] = {1, 1, 499122177, 166374059, 291154603, 856826403};
const pint i100 = 828542813;
using Poly = array<pint, K + 1>;
pint Pow(pint base, int power) {
pint res(1); for (; power; base *= base, power >>= 1) if (power & 1) res *= base;
return res;
}
Poly mul(const Poly& x, const Poly& y) {
Poly z;
for (int i = 0; i <= k; ++i)
for (int j = 0; i + j <= k; ++j) z[i + j] += x[i] * y[j];
return z;
}
Poly div(const Poly& x, const Poly& y) {
Poly z; pint iv = Pow(y[0], P - 2);
for (int i = 0; i <= k; ++i) {
z[i] = x[i]; for (int j = 0; j < i; ++j) z[i] -= z[j] * y[i - j];
z[i] *= iv;
}
return z;
}
pint p[N], q[N], ans;
Poly f[N];
vector<int> e[N], w;
void dfs0(int u, int fa) {
for (const auto& v: e[u]) {
if (v == fa) continue;
dfs0(v, u); f[u] = mul(f[u], f[v]);
}
ans += f[u][k] * q[fa];
f[u][0] += q[u];
}
void solvetree() { q[0] = 1; dfs0(1, 0); printf("%d\n", (ans * fac[k]).v); }
int getnext(int u) {
for (const auto& v: e[u]) if (!vis[v]) return v;
return 0;
}
void solvecircle() {
queue<int> Q;
for (int i = 1; i <= n; ++i) if (deg[i] == 1) Q.push(i);
while (!Q.empty()) {
int u = Q.front(), fa = getnext(u); Q.pop(); vis[u] = true;
ans += f[u][k] * q[fa];
f[u][0] += q[u]; f[fa] = mul(f[fa], f[u]);
--deg[fa]; if (deg[fa] <= 1) Q.push(fa);
}
for (int i = 1; i <= n; ++i) if (!vis[i]) { w.em(i); break; }
int tmp = getnext(w.back()); vis[w[0]] = true;
while (tmp) { w.em(tmp); vis[tmp] = true; tmp = getnext(w.back()); }
Poly g; g[0] = 1; for (const auto& u: w) g = mul(g, f[u]);
ans += g[k]; for (const auto& u: w) ans += div(g, f[u])[k] * q[u];
Poly h = div(div(g, f[w[0]]), f[w[1]]);
g[0] = q[w[1]]; for (int i = 1; i <= k; ++i) g[i] = 0;
for (int i = 2; i < (int)w.size(); ++i) { g = mul(g, f[w[i]]); g[0] += q[w[i]]; }
for (int i = 0; i < (int)w.size(); ++i) {
int u = w[i], u1 = w[(i + 1) % w.size()], u2 = w[(i + 2) % w.size()];
ans += g[k] * q[u];
for (int j = 0; j <= k; ++j) g[j] -= h[j] * q[u1];
h = mul(div(h, f[u2]), f[u]);
g = mul(g, f[u]); g[0] += q[u];
}
printf("%d\n", (ans * fac[k]).v);
}
int main() {
scanf("%d%d%d", &n, &m, &k);
for (int i = 1; i <= n; ++i) { scanf("%d", &p[i].v); p[i] *= i100; q[i] = pint(1) - p[i]; }
for (int i = 1; i <= m; ++i) { int u, v; scanf("%d%d", &u, &v); e[u].em(v); e[v].em(u); ++deg[u]; ++deg[v]; }
for (int i = 1; i <= n; ++i)
for (int j = 0; j <= k; ++j) f[i][j] = p[i] * caf[j];
if (m == n - 1) solvetree();
else solvecircle();
return 0;
}