P14928 [北大集训 2025] 深红

· · 题解

尝试不用群论语言复述一下这题的做法(?,感谢 ChatGPT 老师的指导。

考虑第二问。由特殊性质启发,我们声称,答案为对所有 x \in [0, n), y \in [0, m),满足对所有 i, j 都有 p(A_{i, j}) = A_{(i + x) \bmod n, (j + y) \bmod m} 的矩阵 A 数量之和,再除以 nm

证明考虑设一个相似类有 t 个矩阵,那么类中一个矩阵有 \frac{nm}{t} 个平移量 (x, y) 能把 A 变成 f(A),又因为类中有 t 个矩阵,所以一个类恰好被统计 t \times \frac{nm}{t} = nm 次。

考虑如何对单个平移量 (x, y),求对所有 i, j 满足 p(A_{i, j}) = A_{(i + x) \bmod n, (j + y) \bmod m} 的矩阵 A 数量。设排列有 c_i 个元素满足其所在环长是 i 的因数。那么若在 (i, j) \to ((i + x) \bmod n, (j + y) \bmod m) 连有向边,会连出一些大小为 \text{size} = \operatorname{lcm}(\frac{n}{\gcd(n, x)}, \frac{m}{\gcd(m, y)}) 的环,总共有 \frac{nm}{\text{size}} 个。贡献即为 c_{\text{size}}\frac{nm}{\text{size}} 次方。只枚举 \gcd(n, x)\gcd(m, y),欧拉函数算一下对应的 x, y 的个数,O(nm) 预处理后可以做到 O(d(n) d(m) \log nm) 的复杂度。

考虑第一问。现在设一个相似类有 t 个矩阵,我们希望所有相似类的贡献是对应的 t 而不是 1

我们考虑对一个平移量的集合 S \subseteq [0, n) \times [0, m),算有多少个类,使得类中矩阵满足,经过 (x, y) 的平移变换后不变的平移量集合恰好是 S。定义这样的集合为这个类的不动集合。

考虑 S 需要满足 \forall (x_1, y_1), (x_2, y_2) \in S, ((x_1 + x_2) \bmod n, (y_1 + y_2) \bmod m) \in S 即对于加法封闭,且 \forall (x, y) \in S, ((-x) \bmod n, (-y) \bmod m) \in S,所以感受到可能的 S 不是很多。

(n', 0) 为满足第一维非零且最小、第二维为零的 S 中的平移量,同理设 (0, m') 为满足第一维为零、第二维非零且最小的 S 中平移量。整个矩阵满足 A_{i, j} = A_{(i + n') \bmod n, j} = A_{i, (j + m') \bmod m},也就是有行周期 n' 和列周期 m',所以可以只考虑 n' \times m' 的子矩阵。

(x, y) 为满足 x \in [1, n'), y \in [1, m')(x, y) \in Sx 最小的平移量(有可能不存在),称这个为 S 的斜向分量。设 s = \gcd(n', m'),发现 (x, y)(x, y) 不存在可以由一个整数 0 \le q < s 唯一确定。 设 g = \gcd(q, s),那么 (x, y) = (\frac{n'}{s} g, \frac{m'}{s} q)(若 q = 0 说明 (x, y) = (n', 0) 表示斜向分量不存在)。也就是把 n' \times m' 矩阵分成 s \times s 个块,每块大小均为 \frac{n'}{s} \times \frac{m'}{s}。每次向右循环地移动 q 个整块的同时向下移动 g 个整块,移动 \frac{s}{g} 次恰好回到第一行且回到原点。

所以我们进而可以证明,所有可能的 S 和满足 n' \mid n, m' \mid m, 0 \le q < \gcd(n', m') 的三元组 (n', m', q) 构成双射。我们有三个平移量 (n', 0), (0, m'), (x, y) = (\frac{n'}{s} g, \frac{m'}{s} q),用它们就可以生成出 S 中所有元素。同时这个集合满足,合并对应平移量的格子后,还剩 l = \frac{n' m' g}{s} 个变量格。

算不动集合恰好为 S 的类的数量比较困难,我们可以算不动集合包含 S 的类的数量,最后根据 S 的包含关系容斥。

可以发现,设 t_1 = \gcd(n', m', x, y), t_2 = \frac{l}{t_1},合并后的变量格,在平移意义下,等价于一个 t_1 \times t_2 的循环矩阵。证明是考虑平移量可以任意加减,坐标轴也可以进行 (x, y) \to (x + y, y), (x, y) \to (x, x + y), (x, y) \to (y, x) 的变换满足变换后仍然能一一对应,最终经过辗转相减可以凑出 (t_1, 0)

所以我们算出 t_1 \times t_2 的矩阵有多少个相似类,就是不动集合包含 S 的类的数量。用第二问的方法可以做到单次 O(d(t_1) d(t_2) \log t_1 t_2) 的复杂度。

最后我们还要进行容斥,即把所有集合按照 |S_i| 从大到小处理,对于 S_i \subsetneq S_j,令 ans_i 减去 ans_j,最后算出的 ans_i 就是不动集合恰好为 S_i 的类的数量。

考虑给两个三元组 (n'_1, m'_1, q_1), (n'_2, m'_2, q_2),怎么判断前者对应的 S 包含于后者。首先要满足 n'_2 \mid n'_1, m'_2 \mid m'_1。然后设两者的斜向分量分别为 (x_1, y_1), (x_2, y_2),检查 (x_1, y_1) 是否被后者生成。

考虑给定三元组 (n', m', q) 和一个平移量 (X, Y),如何判断 (X, Y) 能被 (n', m', q) 生成。首先因为有行周期 n' 和列周期 m',可以令 X \gets X \bmod n', Y \gets Y \bmod m'。设三元组对应的斜向分量为 (x, y),那么能被生成等价于 x \mid X\frac{X}{x} \cdot y \equiv Y \pmod {m'}

时间复杂度 O(\text{可过})

:::info[代码]

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
using ll = long long;
using ull = unsigned long long;
using db = double;
using ldb = long double;
using pii = pair<int, int>;
using pll = pair<ll, ll>;

namespace IO {
    const int maxn = 1 << 20;

    char ibuf[maxn], *iS, *iT, obuf[maxn], *oS = obuf;

    inline char gc() {
        return (iS == iT ? iT = (iS = ibuf) + fread(ibuf, 1, maxn, stdin), (iS == iT ? EOF : *iS++) : *iS++);
    }

    template<typename T = int>
    inline T read() {
        char c = gc();
        T x = 0;
        bool f = 0;
        while (c < '0' || c > '9') {
            f |= (c == '-');
            c = gc();
        }
        while (c >= '0' && c <= '9') {
            x = (x << 1) + (x << 3) + (c ^ 48);
            c = gc();
        }
        return f ? ~(x - 1) : x;
    }

    inline int reads(char *s) {
        char c = gc();
        int len = 0;
        while (isspace(c)) {
            c = gc();
        }
        while (!isspace(c) && c != EOF) {
            s[len++] = c;
            c = gc();
        }
        s[len] = '\0';
        return len;
    }

    inline string reads() {
        char c = gc();
        string s;
        while (isspace(c)) {
            c = gc();
        }
        while (!isspace(c) && c != EOF) {
            s += c;
            c = gc();
        }
        return s;
    }

    inline void flush() {
        fwrite(obuf, 1, oS - obuf, stdout);
        oS = obuf;
    }

    struct Flusher {
        ~Flusher() {
            flush();
        }
    } AutoFlush;

    inline void pc(char ch) {
        if (oS == obuf + maxn) {
            flush();
        }
        *oS++ = ch;
    }

    inline void write(char *s) {
        for (int i = 0; s[i]; ++i) {
            pc(s[i]);
        }
    }

    inline void write(const char *s) {
        for (int i = 0; s[i]; ++i) {
            pc(s[i]);
        }
    }

    template<typename T>
    inline void write(T x) {
        static char stk[64], *tp = stk;
        if (x < 0) {
            x = ~(x - 1);
            pc('-');
        }
        do {
            *tp++ = x % 10;
            x /= 10;
        } while (x);
        while (tp != stk) {
            pc((*--tp) | 48);
        }
    }

    template<typename T>
    inline void writesp(T x) {
        write(x);
        pc(' ');
    }

    template<typename T>
    inline void writeln(T x) {
        write(x);
        pc('\n');
    }
}

using IO::read;
using IO::reads;
using IO::write;
using IO::pc;
using IO::writesp;
using IO::writeln;

const int maxn = 1000100;
const int mod = 998244353;

inline int qpow(int b, int p) {
    int res = 1;
    while (p) {
        if (p & 1) {
            res = 1ULL * res * b % mod;
        }
        b = 1ULL * b * b % mod;
        p >>= 1;
    }
    return res;
}

inline void fix(int &x) {
    x += ((x >> 31) & mod);
}

int n, m, K, a[maxn], b[maxn], c[maxn], pr[maxn / 10], mpr[maxn], tot, tt, phi[maxn], inv[maxn];
bool vis[maxn];
pii d[99];

void dfs(int d, int x, vector<int> &di) {
    if (d > tt) {
        di.pb(x);
        return;
    }
    for (int i = 0; i <= ::d[d].scd; ++i, x *= ::d[d].fst) {
        dfs(d + 1, x, di);
    }
}

inline vector<int> getd(int n) {
    vector<int> di;
    tt = 0;
    while (n > 1) {
        int t = mpr[n], k = 0;
        while (n % t == 0) {
            n /= t;
            ++k;
        }
        d[++tt] = pii(t, k);
    }
    dfs(1, 1, di);
    return di;
}

inline int calc(int n, int m) {
    auto dn = getd(n), dm = getd(m);
    int ans = 0;
    for (int x : dn) {
        for (int y : dm) {
            int z = x * y / __gcd(x, y);
            fix(ans += 1ULL * phi[x] * phi[y] * qpow(c[z], n * m / z) % mod - mod);
        }
    }
    ans = 1ULL * ans * inv[n * m] % mod;
    return ans;
}

struct node {
    int n, m, q, x, y, l, w;
    node(int _n = 0, int _m = 0, int _q = 0, int _x = 0, int _y = 0, int _l = 0, int _w = 0) : n(_n), m(_m), q(_q), x(_x), y(_y), l(_l), w(_w) {}
} f[99999];

inline bool check(node &a, int x, int y) {
    x %= a.n;
    y %= a.m;
    if (x % a.x) {
        return 0;
    }
    int t = x / a.x;
    return a.y * t % a.m == y;
}

inline bool check(node &a, node &b) {
    if (a.n % b.n || a.m % b.m) {
        return 0;
    }
    return check(b, a.x, a.y);
}

void solve() {
    n = read();
    m = read();
    K = read();
    for (int i = 1; i <= K; ++i) {
        a[i] = read();
    }
    for (int i = 1; i <= K; ++i) {
        if (vis[i]) {
            continue;
        }
        int cnt = 0, u = i;
        do {
            ++cnt;
            vis[u] = 1;
            u = a[u];
        } while (u != i);
        b[cnt] += cnt;
    }
    for (int i = 1; i <= n * m; ++i) {
        for (int j = i; j <= n * m; j += i) {
            c[j] += b[i];
        }
    }
    mems(vis, 0);
    phi[1] = inv[1] = 1;
    for (int i = 2; i <= n * m; ++i) {
        inv[i] = 1ULL * (mod - mod / i) * inv[mod % i] % mod;
    }
    for (int i = 2; i <= n * m; ++i) {
        if (!vis[i]) {
            pr[++tot] = i;
            mpr[i] = i;
            phi[i] = i - 1;
        }
        for (int j = 1; j <= tot && i * pr[j] <= n * m; ++j) {
            vis[i * pr[j]] = 1;
            mpr[i * pr[j]] = pr[j];
            if (i % pr[j] == 0) {
                phi[i * pr[j]] = phi[i] * pr[j];
                break;
            }
            phi[i * pr[j]] = phi[i] * (pr[j] - 1);
        }
    }
    tot = 0;
    for (int u = 1; u <= n; ++u) {
        if (n % u) {
            continue;
        }
        for (int v = 1; v <= m; ++v) {
            if (m % v) {
                continue;
            }
            int s = __gcd(u, v);
            for (int q = 0; q < s; ++q) {
                int g = __gcd(q, s);
                int x = u / s * g, y = v / s * q;
                int t1 = __gcd(s, __gcd(x, y));
                int cnt = u * v / s * g, t2 = cnt / t1;
                f[++tot] = node(u, v, q, x, y, cnt, calc(t1, t2));
            }
        }
    }
    sort(f + 1, f + tot + 1, [&](const node &a, const node &b) {
        return a.l < b.l;
    });
    int ans1 = 0, ans2 = 0;
    for (int i = 1; i <= tot; ++i) {
        for (int j = 1; j < i; ++j) {
            if (check(f[i], f[j])) {
                fix(f[i].w -= f[j].w);
            }
        }
        fix(ans1 += 1ULL * f[i].w * f[i].l % mod - mod);
        fix(ans2 += f[i].w - mod);
    }
    writeln(ans1);
    writeln(ans2);
}

int main() {
    int T = 1;
    // scanf("%d", &T);
    while (T--) {
        solve();
    }
    return 0;
}

:::