ABC259Ex

· · 题解

思路:

由于只要求起点和终点的颜色相同,那么可以枚举颜色。

对于每种颜色,分别考虑。

设当前颜色为 c

考虑两种暴力:

然后我们发现当当前颜色不超过 n 个时,暴力 2 的时间总复杂度仅为 \mathcal{O(n^3)},而超过 n 个时,由于最多有 n 个颜色的点数 \ge n(要不然总点数就超过 n^2 了),那么用暴力 1 总时间复杂度也为 \mathcal{O(n^3)},所以两个结合起来就可以过啦。

总时间复杂度为 \mathcal{O(n^3)}

Code:

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

#define int ll

const int N = 1010, mod = 998244353;
int fac[N], inv[N];

int n;
int a[410][410];
int f[410][410];
vector<pair<int, int>> e[160010];

int C(int n, int m) {
    if (n < m) return 0;
    if (n < 0 || m < 0) return 0;
    return fac[n] * inv[m] % mod * inv[n - m] % mod;
}

int qmi(int a, int k, int p) {
    int res = 1;
    while (k){
        if (k & 1) res = (ll)res * a % p;
        a = (ll)a * a % p;
        k >>= 1;
    }
    return res;
}

signed main() {
    fac[0] = inv[0] = 1;
    for (int i = 1; i < N; ++i) {
        fac[i] = (ll)fac[i - 1] * i % mod;
        inv[i] = (ll)inv[i - 1] * qmi(i, mod - 2, mod) % mod;
    }
    scanf("%lld", &n);
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= n; ++j) {
            scanf("%lld", &a[i][j]);
            e[a[i][j]].push_back({i, j});
        }
    }

    ll ans = 0;

    for (int i = 1; i <= n * n; ++i) {
        if (!e[i].size()) continue;
        if (e[i].size() <= n) {
            for (auto [X1, Y1] : e[i])
                for (auto [X2, Y2] : e[i]) {
                    ans = (ans + C(X2 - X1 + Y2 - Y1, X2 - X1)) % mod;
                }
        } else {
            for (int j = 1; j <= n; ++j)
                for (int k = 1; k <= n; ++k) {
                    f[j][k] = (a[j][k] == i);
                    f[j][k] = (f[j][k] + f[j][k - 1]) % mod;
                    f[j][k] = (f[j][k] + f[j - 1][k]) % mod;
                    ans += f[j][k] * (a[j][k] == i);
                    ans %= mod;
                }
        }
    }

    printf("%lld\n", ans);
    return 0;
}