P14808 [CCPC 2024 哈尔滨站] 子序列计数

· · 题解

k \gets k^{-1} \bmod L

我们发现子序列的 DP 可以用矩阵描述。与其去算子序列,我们不妨去算一个广义的信息:把 s_i 变成一个 (m + 1) \times (m + 1) 的矩阵,然后依次把 i = 0 \sim L - 1s_{ik \bmod L} 矩阵乘起来,求最终的矩阵。

为了直观理解这个问题,我们可以把所有位置排成一个 k 列网格图,第 i 列(i \in [0, k - 1])从上到下分别是 i, i + k, i + 2k, \cdots, i + ck,满足 i + ck < Li + (c + 1)k \ge L,即再走一步就要取模。

发现我们走的顺序是先从第 0 列开始,每次走完完整的一列,然后设走完了第 x 列,跳到第 (x + k - L) \bmod k 列继续走。

那么我们把每一列的矩阵乘积算出来,就可以递归到 L' = kk' = k - L \bmod k 的子问题。递归终止条件为 L = 1

同时若初始的 sn 段,那么算每一列的矩阵乘积得到的 s' 只有 n + O(1) 段。

要求出对应的段以及矩阵乘积是简单的,我们开一棵线段树维护新的每一段的矩阵乘积,那么原本 s 的一段(设其为 [l, r],对应矩阵为 a)的作用相当于,若 \left\lfloor\frac{l}{k}\right\rfloor = \left\lfloor\frac{r}{k}\right\rfloor 相当于给 [l \bmod k, r \bmod k] 区间乘上 a,否则给从 l \bmod k 开始的一段前缀和 r \bmod k 结尾的一段后缀乘上 a,然后整体乘上 a^{\left\lfloor\frac{r}{k}\right\rfloor - \left\lfloor\frac{l}{k}\right\rfloor}

有个小问题,若 L, k 非常接近,那么递归的复杂度不对。

我们做一个修正:若 2k > L,令 k \gets L - k,那么相当于是把 s 除了第一位 reverse。这样递归次数就是 O(\log L) 了。

总时间复杂度 O(m^3 n \log n \log L),常数很小。

:::info[代码]

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
using ll = long long;
using ull = unsigned long long;
using db = double;
using ldb = long double;
using pii = pair<int, int>;
using pll = pair<ll, ll>;

const int maxn = 2020;
const int mod = 998244353;

inline void fix(int &x) {
    x += ((x >> 31) & mod);
}

int n, m, K, L, a[12], b[maxn << 1];

void exgcd(int a, int b, int &x, int &y) {
    if (!b) {
        x = 1;
        y = 0;
        return;
    }
    exgcd(b, a % b, y, x);
    y -= a / b * x;
}

inline int inv(int a, int p) {
    int x, y;
    exgcd(a, p, x, y);
    return (x % p + p) % p;
}

struct mat {
    int a[12][12];
    mat() {
        mems(a, 0);
    }
} I;

inline mat operator * (const mat &a, const mat &b) {
    mat res;
    for (int i = 0; i <= m; ++i) {
        for (int k = 0; k <= m; ++k) {
            if (!a.a[i][k]) {
                continue;
            }
            for (int j = 0; j <= m; ++j) {
                if (!b.a[k][j]) {
                    continue;
                }
                fix(res.a[i][j] += 1ULL * a.a[i][k] * b.a[k][j] % mod - mod);
            }
        }
    }
    return res;
}

inline mat qpow(mat a, int p) {
    mat res = I;
    while (p) {
        if (p & 1) {
            res = res * a;
        }
        a = a * a;
        p >>= 1;
    }
    return res;
}

namespace SGT {
    mat a[maxn << 2];
    bool vis[maxn << 2];

    void build(int rt, int l, int r) {
        a[rt] = I;
        vis[rt] = 0;
        if (l == r) {
            return;
        }
        int mid = (l + r) >> 1;
        build(rt << 1, l, mid);
        build(rt << 1 | 1, mid + 1, r);
    }

    inline void pushtag(int x, const mat &y) {
        a[x] = a[x] * y;
        vis[x] = 1;
    }

    inline void pushdown(int x) {
        if (!vis[x]) {
            return;
        }
        pushtag(x << 1, a[x]);
        pushtag(x << 1 | 1, a[x]);
        a[x] = I;
        vis[x] = 0;
    }

    void update(int rt, int l, int r, int ql, int qr, const mat &x) {
        if (ql > qr) {
            return;
        }
        if (ql <= l && r <= qr) {
            pushtag(rt, x);
            return;
        }
        pushdown(rt);
        int mid = (l + r) >> 1;
        if (ql <= mid) {
            update(rt << 1, l, mid, ql, qr, x);
        }
        if (qr > mid) {
            update(rt << 1 | 1, mid + 1, r, ql, qr, x);
        }
    }

    void dfs(int rt, int l, int r, vector<pair<int, mat>> &vc) {
        if (l == r) {
            vc.pb(b[l + 1] - b[l], a[rt]);
            return;
        }
        pushdown(rt);
        int mid = (l + r) >> 1;
        dfs(rt << 1, l, mid, vc);
        dfs(rt << 1 | 1, mid + 1, r, vc);
    }
}

mat work(int l, int k, vector<pair<int, mat>> vc) {
    if (l == 1) {
        return vc[0].scd;
    }
    if (k * 2 > l) {
        mat A = vc[0].scd;
        reverse(vc.begin(), vc.end());
        if (vc.back().fst == 1) {
            vc.pop_back();
        } else {
            --vc.back().fst;
        }
        vc.insert(vc.begin(), mkp(1, A));
        return work(l, l - k, vc);
    }
    int n = (int)vc.size();
    b[1] = 0;
    for (int i = 0; i < n; ++i) {
        b[i + 2] = (b[i + 1] + vc[i].fst) % k;
    }
    sort(b + 1, b + n + 2);
    n = unique(b + 1, b + n + 2) - b - 1;
    b[n + 1] = k;
    int s = 0;
    SGT::build(1, 1, n);
    for (auto p : vc) {
        int l = s, r = s + p.fst - 1;
        s += p.fst;
        int x = lower_bound(b + 1, b + n + 1, l % k) - b, y = lower_bound(b + 1, b + n + 1, (r + 1) % k) - b - 1;
        if (l / k == (r + 1) / k) {
            SGT::update(1, 1, n, x, y, p.scd);
        } else {
            SGT::update(1, 1, n, x, n, p.scd);
            SGT::update(1, 1, n, 1, n, qpow(p.scd, (r + 1) / k - l / k - 1));
            SGT::update(1, 1, n, 1, y, p.scd);
        }
    }
    vector<pair<int, mat>> nv;
    SGT::dfs(1, 1, n, nv);
    return work(k, k - l % k, nv);
}

void solve() {
    scanf("%d%d%d%d", &n, &m, &K, &L);
    for (int i = 1; i <= m; ++i) {
        scanf("%d", &a[i]);
    }
    K = inv(K, L);
    for (int i = 0; i <= m; ++i) {
        I.a[i][i] = 1;
    }
    vector<pair<int, mat>> vc;
    for (int i = 1, x, y; i <= n; ++i) {
        scanf("%d%d", &x, &y);
        mat A = I;
        for (int j = 1; j <= m; ++j) {
            if (y == a[j]) {
                A.a[j - 1][j] = 1;
            }
        }
        vc.pb(x, A);
    }
    printf("%d\n", work(L, K, vc).a[0][m]);
}

int main() {
    int T = 1;
    // scanf("%d", &T);
    while (T--) {
        solve();
    }
    return 0;
}

:::