届かない恋

· · 题解

宝宝容斥。下文记 S=nm

先对序列排序。考虑求出 f_k 表示 bingo 整数大于 a_k 的方案数。那么答案为 \sum f_k(a_{k+1}-a_k)

暴力容斥求出 f_k,记 c_k 表示序列中小于等于 a_k 的数个数,我们把 \le a_k 的位置记作 0 反之记作 1,那么相当于填充 01 矩阵,要求不存在一行或者一列全为 0。容斥枚举 x 行全为 0y 列全为 0。此时钦定为 0 的位置有 t=mx+ny-xy 个,贡献方案数为 \binom{n}{x}\binom{m}{y}\binom{c}{t}t!(nm-t)!,表示先选择一些行列全为 0,然后填数,从 \le a_k 的数里面选择 t 个填入钦定的位置,其他任意填。容斥系数为 (-1)^{x+y}。暴力枚举做到 O(S^2)

观察 f_k 的式子,此处 x,y 不用枚举到 n,m 原因是不存在这种情况:

\begin{aligned} f_k&=\sum_{x=0}^{n-1}\sum_{y=0}^{m-1} \binom{c_k}{t}(-1)^{x+y}\binom{n}{x}\binom{m}{y}t!(nm-t)!\\&=c_k!\sum_{x=0}^{n-1}\sum_{y=0}^{m-1} \frac{1}{(c_k-t)!t!}(-1)^{x+y}\binom{n}{x}\binom{m}{y}t!(nm-t)! \end{aligned}

发现后面这个东西是关于 t 的式子和 \dfrac{1}{(c_k-t)!} 的卷积,所以直接 ntt 做完了,复杂度 O(S\log S)

因为太丑,代码里把 ntt 板子去掉了。

#include <bits/stdc++.h>
#define LL long long
#define ull unsigned long long
#define uint unsigned int
using namespace std;
const int N = 1e6 + 10;
const ull MOD = 998244353;
ull Qpow(ull x, ull k, ull P) {
    ull res = 1, tmp = x;
    for (; k; k >>= 1, tmp = tmp * tmp % P) if (k & 1) res = res * tmp % P;
    return res;
}
int n, m, S, A[N], tot, tmp[N];

ull fact[N], inv[N]; int lim = 1e6;
ull C(int n, int m) { 
    return (n < m ? 0 : fact[n] * inv[m] % MOD * inv[n - m] % MOD); 
}

ull G[N], F[N];

namespace FNTT { /* NTT */ } using FNTT::NTT;

int main() {
    freopen(".in", "r", stdin); freopen(".out", "w", stdout);
    ios::sync_with_stdio(false); cin.tie(0), cout.tie(0);

    fact[0] = 1;
    for (int i = 1; i <= lim; i ++) fact[i] = fact[i - 1] * i % MOD;
    inv[lim] = Qpow(fact[lim], MOD - 2, MOD);
    for (int i = lim; i >= 1; i --) inv[i - 1] = inv[i] * i % MOD;

    int _; cin >> _;
    while (_ --) {
        cin >> n >> m; S = n * m;
        for (int i = 1; i <= S; i ++) cin >> A[i];
        sort(A + 1, A + 1 + S);
        for (int i = 1; i <= S; i ++) tmp[i] = A[i];
        tot = unique(tmp + 1, tmp + 1 + S) - tmp - 1;
        int len = 1; while (len <= (S << 1)) len <<= 1;
        for (int i = 0; i < len; i ++) F[i] = G[i] = 0;
        for (int i = 0; i < S; i ++) F[i] = 0, G[i] = inv[i];
        for (int a = 0; a < n; a ++) for (int b = 0; b < m; b ++) {
            int t = n * b + a * m - a * b;
            ull tmp = inv[t] * fact[t] % MOD 
                * fact[S - t] % MOD * C(n, a) % MOD * C(m, b) % MOD;
            if ((a + b) & 1) F[t] += MOD - tmp;
            else F[t] += tmp;
        }
        for (int i = 0; i < len; i ++) F[i] %= MOD;
        NTT(F, len, 1); NTT(G, len, 1);
        for (int i = 0; i < len; i ++) F[i] = F[i] * G[i] % MOD;
        NTT(F, len, -1);
        ull Ans = fact[S] * tmp[1] % MOD;
        for (int i = 1, c = 0; i < tot; i ++) {
            while (c < S && A[c + 1] <= tmp[i]) ++ c;
            ull res = fact[c] * F[c] % MOD;
            Ans += res * (tmp[i + 1] - tmp[i]) % MOD;
        }
        cout << Ans % MOD << "\n";
    }
    return 0;
}