题解:P5575 [CmdOI2019] 黑白图

· · 题解

看不懂官解在干什么(摊手),但我们可以想出一个 O(n k^{2}) 的解法(虽然运行时长似乎没有优势),而且更好写,故记录如下。

首先,显然有一个暴力 dp 做法,即把当前点 u 的连通块大小纳入状态,但这样单是状态数就爆炸了。

但我们发现本题的 k 特别小,只有 5,所以我们考虑统计 k 相关的东西。

我们记一个连通块为 G,那就是求 |G|^{k},其组合意义是生成一个序列 G^{k},那一个自然的想法是统计这个序列,因为它只有 k 的长度。

那么如何统计呢?如果直接统计序列会比较麻烦,我们常见的 dp 都是统计集合或者可重集,所以我们先把序列转化为这两者。

是的,这两个都可以做,而且建议读者两个都思考一下,我觉得这非常有价值。

如果我们选择转化为集合,那么我们要讨论的是多少个序列会被转换为一个大小为 i 的集合,具体而言,有多少个序列 [i]^{k} 满足 i 个元素全出现过?

这个问题用容斥可以解决,或者你一眼看出这个就是 i! S(k, i),其中 S(k, i) 为第二类斯特林数(不会吧,不会笔者当时没看出来吧)。

那我们可以把式子写一下(记 U 为一个染色方案下所有连通块的集合,P_{U} 为其出现概率):

\begin{align} E \left[ \sum_{G \in U} |G|^{k} \right] &= \sum_{U} P_{U} \sum_{G \in U} |G|^{k} \\ &= \sum_{u} P_{u} \sum_{G \in U} \sum_{1 \le i \le k} \sum_{\{ u_{1}, u_{2}, \dots, u_{i} \} \subseteq G} i! S(k, i) \\ &= \sum_{1 \le i \le k} i! S(k, i) \sum_{G} \left( \sum_{\{u_{1}, u_{2}, \dots, u_{i} \} \subseteq G} 1 \right) \left( \sum_{G \in U} P_{U} \right) \\ &= \sum_{1 \le i \le k} i! S(k, i) \sum_{\{u_{1}, u_{2}, \dots, u_{i} \}} \sum_{\{u_{1}, u_{2}, \dots, u_{i} \} \subseteq G} \sum_{G \in U} P_{U} \end{align}

其实 \sum_{G \in U} P_{U} 就是 G 出现的概率,所以我们只需要统计所有大小等于 i \le k 的点集处于一个连通块中的概率即可。

当然,我保留 (3) 式是因为之后我们具体计算的时候更多考虑这个式子的组合意义(以防你不知道,它中间那个括号实际上是 \binom{|G|}{i})。

如果写出 dp 方程,我们有 f_{u, i} 表示 u 子树中,所有包含 u 的连通块的出现概率乘它们的 \binom{|G|}{i} 之和,转移有:

统计答案的时候就对每个点都按照上面期望的式子加起来,记得乘不选父节点的概率(如果有)。

当然,作为久经考验的 OIer,你一眼就看出了这就是生成函数。

我们不妨把它写成这样:

F_{u}(x) = (1 - p_{u}) + p_{u} (1 + x) \prod_{v} F_{v}(x)

这看着就漂亮多了。当然如果只做树的情况这是没什么用的,在基环树的情况下这样写的优点才会体现出来。

如果我们选择转换为可重集,那么我们要讨论的是多少个序列会被转换为一个可重集,这个问题就显然多了,是多重组合数,或者说我们一般用 \exp 生成函数处理。

因为和集合的情况类似,我们直接给出结论(作为久经考验的 OIer,你其实可以直接猜出下面的式子):

F_{u}(x) = (1 - p_{u}) + p_{u} \left( \sum_{0 \le i \le k} \frac{x^{i}}{i!} \right) \prod_{v} F_{v}(x)

统计答案的时候只需要统计第 \left[ \frac{x^{k}}{k!} \right] 项即可,同样要乘父节点不选的概率。

总结下,我们树的情况直接 dfs 维护一个多项式即可。暴力多项式乘法,时间复杂度 O(n k^{2})(当然你非要写个 NTT 然后宣称时间复杂度来到 O(n k \log k) 我也没法说啥)。

可是这一题还有基环树的情况,这就麻烦了。

我们考虑基环树上的连通块的形态,大致可以分为三种:

我们设环上的点为 w_{1, 2, \dots, z},如果一个点 w_{i} 被包括在连通块内,那这个连通块的生成函数应乘上 G_{w_{i}} = F_{w_{i}}(x) - (1 - p_{w_{i}}),这是显然的。

如果我们直接这样一个个乘,那么时间复杂度会来到 O(n^{2} k^{2})。但这是多项式,我们是有除法的。

所以我们对于这三类连通块,有:

这样我们就搞定了。时间复杂度 O(n k^{2})

最后一个 trick,就是用拓扑排序处理基环树会好写一点。

集合版本代码:

#include <bits/stdc++.h>
#define em emplace_back
using namespace std;
const int N = 2e5 + 5, P = 998244353, K = 5;
int n, m, k, deg[N];
bool vis[N];

struct pint {
    int v; pint(int x = 0): v(x) { }
    pint operator + (const pint& x) const { return pint(v + x.v >= P ? v + x.v - P : v + x.v); }
    pint operator - (const pint& x) const { return pint(v - x.v < 0 ? v - x.v + P : v - x.v); }
    pint operator * (const pint& x) const { return pint(1ll * v * x.v % P); }
    pint& operator += (const pint& x) { v = v + x.v >= P ? v + x.v - P : v + x.v; return *this; }
    pint& operator -= (const pint& x) { v = v - x.v < 0 ? v - x.v + P : v - x.v; return *this; }
    pint& operator *= (const pint& x) { v = 1ll * v * x.v % P; return *this; }
};
const pint A[K + 1][K + 1] = {{}, {0, 1}, {0, 1, 2}, {0, 1, 6, 6}, {0, 1, 14, 36, 24}, {0, 1, 30, 150, 240, 120}};
const pint i100 = 828542813;
using Poly = array<pint, K + 1>;

pint Pow(pint base, int power) {
    pint res(1); for (; power; base *= base, power >>= 1) if (power & 1) res *= base;
    return res;
}

Poly mul(const Poly& x, const Poly& y) {
    Poly z;
    for (int i = 0; i <= k; ++i)
        for (int j = 0; i + j <= k; ++j) z[i + j] += x[i] * y[j];
    return z;
}

Poly div(const Poly& x, const Poly& y) {
    Poly z; pint iv = Pow(y[0], P - 2);
    for (int i = 0; i <= k; ++i) {
        z[i] = x[i]; for (int j = 0; j < i; ++j) z[i] -= z[j] * y[i - j];
        z[i] *= iv;
    }
    return z;
}

pint p[N], q[N], ans;
Poly f[N];
vector<int> e[N], w;

void count(const Poly& x, pint mu = 1) {
    for (int i = 1; i <= k; ++i) ans += A[k][i] * x[i] * mu;
}

void dfs0(int u, int fa) {
    f[u][0] = f[u][1] = p[u];
    for (const auto& v: e[u]) {
        if (v == fa) continue;
        dfs0(v, u); f[u] = mul(f[u], f[v]);
    }
    count(f[u], q[fa]);
    f[u][0] += q[u];
}
void solvetree() { q[0] = 1; dfs0(1, 0); printf("%d\n", ans.v); }

int getnext(int u) {
    for (const auto& v: e[u]) if (!vis[v]) return v;
    return -1;
}
void solvecircle() {
    queue<int> Q;
    for (int i = 1; i <= n; ++i) f[i][0] = f[i][1] = p[i];
    for (int i = 1; i <= n; ++i) if (deg[i] == 1) Q.push(i);
    while (!Q.empty()) {
        int u = Q.front(), fa = getnext(u); Q.pop(); vis[u] = true;
        count(f[u], q[fa]);
        f[u][0] += q[u]; f[fa] = mul(f[fa], f[u]);
        --deg[fa]; if (deg[fa] <= 1) Q.push(fa);
    }

    for (int i = 1; i <= n; ++i) if (!vis[i]) { w.em(i); break; }
    int tmp; vis[w[0]] = true; while ((tmp = getnext(w.back())) != -1) { w.em(tmp); vis[tmp] = true; }

    Poly g; g[0] = 1; for (const auto& u: w) g = mul(g, f[u]);
    count(g);
    for (const auto& u: w) count(div(g, f[u]), q[u]);

    Poly h = div(div(g, f[w[0]]), f[w[1]]);
    g[0] = q[w[1]]; for (int i = 1; i <= k; ++i) g[i] = 0;
    for (int i = 2; i < (int)w.size(); ++i) { g = mul(g, f[w[i]]); g[0] += q[w[i]]; }
    for (int i = 0; i < (int)w.size(); ++i) {
        int u = w[i], u1 = w[(i + 1) % w.size()], u2 = w[(i + 2) % w.size()];
        count(g, q[u]);
        for (int j = 0; j <= k; ++j) g[j] -= h[j] * q[u1];
        h = mul(div(h, f[u2]), f[u]);
        g = mul(g, f[u]); g[0] += q[u];
    }
    printf("%d\n", ans.v);
}

int main() {
    scanf("%d%d%d", &n, &m, &k);
    for (int i = 1; i <= n; ++i) { scanf("%d", &p[i].v); p[i] *= i100; q[i] = pint(1) - p[i]; }
    for (int i = 1; i <= m; ++i) { int u, v; scanf("%d%d", &u, &v); e[u].em(v); e[v].em(u); ++deg[u]; ++deg[v]; }
    if (m == n - 1) solvetree();
    else solvecircle();
    return 0;
}

可重集版本代码:

#include <bits/stdc++.h>
#define em emplace_back
using namespace std;
const int N = 2e5 + 5, P = 998244353, K = 5;
int n, m, k, deg[N];
bool vis[N];

struct pint {
    int v; pint(int x = 0): v(x) { }
    pint operator + (const pint& x) const { return pint(v + x.v >= P ? v + x.v - P : v + x.v); }
    pint operator - (const pint& x) const { return pint(v - x.v < 0 ? v - x.v + P : v - x.v); }
    pint operator * (const pint& x) const { return pint(1ll * v * x.v % P); }
    pint& operator += (const pint& x) { v = v + x.v >= P ? v + x.v - P : v + x.v; return *this; }
    pint& operator -= (const pint& x) { v = v - x.v < 0 ? v - x.v + P : v - x.v; return *this; }
    pint& operator *= (const pint& x) { v = 1ll * v * x.v % P; return *this; }
} C[10][10];
const pint fac[6] = {1, 1, 2, 6, 24, 120};
const pint caf[6] = {1, 1, 499122177, 166374059, 291154603, 856826403};
const pint i100 = 828542813;
using Poly = array<pint, K + 1>;

pint Pow(pint base, int power) {
    pint res(1); for (; power; base *= base, power >>= 1) if (power & 1) res *= base;
    return res;
}

Poly mul(const Poly& x, const Poly& y) {
    Poly z;
    for (int i = 0; i <= k; ++i)
        for (int j = 0; i + j <= k; ++j) z[i + j] += x[i] * y[j];
    return z;
}

Poly div(const Poly& x, const Poly& y) {
    Poly z; pint iv = Pow(y[0], P - 2);
    for (int i = 0; i <= k; ++i) {
        z[i] = x[i]; for (int j = 0; j < i; ++j) z[i] -= z[j] * y[i - j];
        z[i] *= iv;
    }
    return z;
}

pint p[N], q[N], ans;
Poly f[N];
vector<int> e[N], w;

void dfs0(int u, int fa) {
    for (const auto& v: e[u]) {
        if (v == fa) continue;
        dfs0(v, u); f[u] = mul(f[u], f[v]);
    }
    ans += f[u][k] * q[fa];
    f[u][0] += q[u];
}
void solvetree() { q[0] = 1; dfs0(1, 0); printf("%d\n", (ans * fac[k]).v); }

int getnext(int u) {
    for (const auto& v: e[u]) if (!vis[v]) return v;
    return 0;
}
void solvecircle() {
    queue<int> Q;
    for (int i = 1; i <= n; ++i) if (deg[i] == 1) Q.push(i);
    while (!Q.empty()) {
        int u = Q.front(), fa = getnext(u); Q.pop(); vis[u] = true;
        ans += f[u][k] * q[fa];
        f[u][0] += q[u]; f[fa] = mul(f[fa], f[u]);
        --deg[fa]; if (deg[fa] <= 1) Q.push(fa);
    }

    for (int i = 1; i <= n; ++i) if (!vis[i]) { w.em(i); break; }
    int tmp = getnext(w.back()); vis[w[0]] = true;
    while (tmp) { w.em(tmp); vis[tmp] = true; tmp = getnext(w.back()); }

    Poly g; g[0] = 1; for (const auto& u: w) g = mul(g, f[u]);
    ans += g[k]; for (const auto& u: w) ans += div(g, f[u])[k] * q[u];

    Poly h = div(div(g, f[w[0]]), f[w[1]]);
    g[0] = q[w[1]]; for (int i = 1; i <= k; ++i) g[i] = 0;
    for (int i = 2; i < (int)w.size(); ++i) { g = mul(g, f[w[i]]); g[0] += q[w[i]]; }
    for (int i = 0; i < (int)w.size(); ++i) {
        int u = w[i], u1 = w[(i + 1) % w.size()], u2 = w[(i + 2) % w.size()];
        ans += g[k] * q[u];
        for (int j = 0; j <= k; ++j) g[j] -= h[j] * q[u1];
        h = mul(div(h, f[u2]), f[u]);
        g = mul(g, f[u]); g[0] += q[u];
    }
    printf("%d\n", (ans * fac[k]).v);
}

int main() {
    scanf("%d%d%d", &n, &m, &k);
    for (int i = 1; i <= n; ++i) { scanf("%d", &p[i].v); p[i] *= i100; q[i] = pint(1) - p[i]; }
    for (int i = 1; i <= m; ++i) { int u, v; scanf("%d%d", &u, &v); e[u].em(v); e[v].em(u); ++deg[u]; ++deg[v]; }
    for (int i = 1; i <= n; ++i)
        for (int j = 0; j <= k; ++j) f[i][j] = p[i] * caf[j];
    if (m == n - 1) solvetree();
    else solvecircle();
    return 0;
}