题解:AT_arc118_e [ARC118E] Avoid Permutations

· · 题解

我们先说明一个定理

\sum_{T \subseteq S} (-1)^{|T|} = [S = \emptyset]

我们记 S 表示路径上的点构成的集合,T 表示一个合法的障碍集合。答案为

\begin{aligned} \sum_{S}\sum_{T} [S \cap T = \emptyset] &= \sum_{S} \sum_{T} \sum_{S' \subseteq S \cap T} (-1)^{|S'|}\\ &= \sum_{S} \sum_{S' \subseteq S} (-1)^{|S'|} \sum_{T \supseteq S'} 1\\ &= \sum_{S} \sum_{S' \subseteq S} (-1)^{|S'|} \sum_{T \supseteq S'} 1\\ \end{aligned}

P 表示已经确定的障碍集合。显然有 T \supseteq P,我们又要求 T \supseteq S',所以 T \supseteq S' \cup P。我们记 k = |S' \cup P|,那么有 kk 列的点是确定的,剩余 n - kn - k 列的点是任意选择的,共有 (n - k)! 种方案。从而 \sum_{T \supseteq S'} 1 = (n - k)!,从而答案

\begin{aligned} \sum_{S} \sum_{S' \subseteq S} (-1)^{|S'|} \sum_{T \supseteq S'} 1 &= \sum_{S} \sum_{S' \subseteq S} (-1)^{|S'|} (n - |S' \cup P|)!\\ \end{aligned}

现在的式子是容易 DP 的。设 f_{x, y, i, p, q} 表示现在走到了 (x, y)|S' \cup P| = i,现在是否已经在第 x 行选择障碍的情况为 p \in \{0, 1\},是否已经在第 y 列选择障碍的情况为 q \in \{0, 1\} 的方案数。有初始化 f_{0, 0, |P|, 0, 0} = 1

我们首先有直接向下一个位置走,不选择下一个位置作为障碍的转移。

f_{x, y, i, p, q} \to f_{x + 1, y, i, 0, q}\\ f_{x, y, i, p, q} \to f_{x, y + 1, i, p, 0}\\

考虑选择下一个位置为障碍的转移。首先,|S'| 会增加 1,所以系数乘上 -1。然后,显然 p = q = 1。最后,如果选定的点在 \{a_i\} 中已经被要求,那么 |S' \cup P| 不变,即 i 不变;否则 i 增加 1

注意并非任意情况下都可以选择下一个位置作为状态。我们有转移

-f_{x, y, i, p, q} \to f_{x + 1, y, i + [a_{x + 1} \neq y], 1, 1}(q = 0, x + 1 \in [1, n], y \in [1, n], (a_{x + 1} = y \vee a_{x + 1} = -1 \wedge (\forall k \in [1, n])(a_k \neq y)))\\ -f_{x, y, i, p, q} \to f_{x, y + 1, i + [a_{x} \neq y + 1], 1, 1}(p = 0, x \in [1, n], y + 1 \in [1, n], (a_{x} = y + 1 \vee a_{x} = -1 \wedge (\forall k \in [1, n])(a_k \neq y + 1)))\\

计算答案时考虑枚举 i = |S' \cup P|,答案即为 \sum_{i = |P|}^{n} (n - i)! \times f_{n + 1, n + 1, i, 0, 0}

时间复杂度 \mathcal{O}(n^3)

Code:

#include<bits/stdc++.h>
#define mem(a, v) memset(a, v, sizeof(a));

using namespace std;

const int maxn = 200 + 10, mod = 998244353;

int n, res = 0;
int a[maxn], fac[maxn], f[maxn][maxn][maxn][2][2];
bitset<maxn> vis;

template<typename Tp_x, typename Tp_y>
inline void mod_add(Tp_x &x, Tp_y y){
    x += y, x >= mod ? x -= mod : x;
}

int main(){
    scanf("%d", &n);
    int cnt = 0;
    fac[0] = 1;
    for (int i = 1; i <= n; i++){
        fac[i] = (long long)fac[i - 1] * i % mod;
        scanf("%d", &a[i]);
        if (~a[i]){
            cnt++, vis.set(a[i]);
        }
    }
    f[0][0][cnt][0][0] = 1;
    for (int i = 0; i <= n + 1; i++){
        for (int j = 0; j <= n + 1; j++){
            for (int k = cnt; k <= n; k++){
                for (int p = 0; p < 2; p++){
                    for (int q = 0; q < 2; q++){
                        if (i <= n){
                            mod_add(f[i + 1][j][k][0][q], f[i][j][k][p][q]);
                            if (!q && i + 1 >= 1 && i + 1 <= n && j >= 1 && j <= n && (a[i + 1] == j || !~a[i + 1] && !vis[j])){
                                mod_add(f[i + 1][j][k + (a[i + 1] != j)][1][1], mod - f[i][j][k][p][q]);
                            }
                        }
                        if (j <= n){
                            mod_add(f[i][j + 1][k][p][0], f[i][j][k][p][q]);
                            if (!p && i >= 1 && i <= n && j + 1 >= 1 && j + 1 <= n && (a[i] == j + 1 || !~a[i] && !vis[j + 1])){
                                mod_add(f[i][j + 1][k + (a[i] != j + 1)][1][1], mod - f[i][j][k][p][q]);
                            }
                        }
                    }
                }
            }
        }
    }
    for (int i = cnt; i <= n; i++){
        mod_add(res, (long long)fac[n - i] * f[n + 1][n + 1][i][0][0] % mod);
    }
    printf("%d", res);

return 0;
}