题解:P16812 [蓝桥杯 2026 国 Python A] 压缩字符串

· · 题解

按照 \texttt{\#} 分段,把每段本质不同子序列的 OGF 乘起来就能得到答案。

求长为 n01s 的本质不同子序列个数,设 F_{i,0/1}[1,i] 前缀中以 0/1 结尾的本质不同子序列的 OGF,考虑 s_i 的值,有:

可以写成矩乘的形式:

\begin{bmatrix}F_{i,0}&F_{i,1}&1\end{bmatrix}=\begin{bmatrix}F_{i-1,0}&F_{i-1,1}&1\end{bmatrix}A_{0/1}\\ A_0=\begin{bmatrix}x&0&0\\x&1&0\\x&0&1\end{bmatrix}\qquad A_1=\begin{bmatrix}1&x&0\\0&x&0\\0&x&1\end{bmatrix}

分治下去乘即可,复杂度为 \mathcal O(n\log^2n),带个 27 倍常数,一跑一个不吱声。

考虑改写成仿射变换的形式:

\begin{bmatrix}F_{i,0}&F_{i,1}\end{bmatrix}=\begin{bmatrix}F_{i-1,0}&F_{i-1,1}\end{bmatrix}A_{0/1}+b_{0/1}\\ A_0=\begin{bmatrix}x&0\\x&1\end{bmatrix}\qquad b_0=\begin{bmatrix}x&0\end{bmatrix}\\ A_1=\begin{bmatrix}1&x\\0&x\end{bmatrix}\qquad b_1=\begin{bmatrix}0&x\end{bmatrix}

合并时 (fA+b)C+d=fAC+bC+d,需要 12 次多项式乘法,18 次 NTT,最终得到的 b 就是答案。

一次合并要 36 次 NTT(直接乘)的代码:

// 間違ったまま 息をし続け
int lim = 1e9; struct node { poly a00, a01, a10, a11, b0, b1; };
il node operator*(const node &a, const node &b) {
    node c = {
        a.a00 * b.a00 + a.a01 * b.a10,
        a.a00 * b.a01 + a.a01 * b.a11,
        a.a10 * b.a00 + a.a11 * b.a10,
        a.a10 * b.a01 + a.a11 * b.a11,
        a.b0 * b.a00 + a.b1 * b.a10 + b.b0,
        a.b0 * b.a01 + a.b1 * b.a11 + b.b1
    };
    auto res = [&](poly &f) -> void { if (f.size() > lim) f.resize(lim); };
    res(c.a00), res(c.a01), res(c.a10), res(c.a11), res(c.b0), res(c.b1);
    return c;
}
il poly calc(const string &s) {
    if (s.empty()) return {1};
    auto sol = [&](auto &self, int l, int r) -> node {
        if (l == r) {
            if (s[l] == '0') return {_x, _0, _x, _1, _x, _0};
            else return {_1, _x, _0, _x, _0, _x};
        }
        int mid = (l + r) >> 1;
        return self(self, l, mid) * self(self, mid + 1, r);
    };
    node f = sol(sol, 0, s.length() - 1);
    return f.b0 + f.b1;
}
int main() {
    int n = rd(), k = rd(); string s = rdst(), t; k -= count(s.begin(), s.end(), '#'), lim = k + 1;
    if (k < 0) return cout << 0, 0; s += '#'; vector<poly> vec;
    for (int i = 0; i <= n; i++) if (s[i] != '#') t += s[i]; else vec.emplace_back(calc(t)), t.clear();
    using ptr = poly*; auto cmp = [](ptr a, ptr b) { return a->size() > b->size(); };
    priority_queue<ptr, vector<ptr>, decltype(cmp)> Q(cmp); for (auto &f : vec) Q.push(&f);
    auto res = [&](poly a) -> poly { if (a.size() > lim) a.resize(lim); return a; };
    while (Q.size() > 1) {
        auto f = Q.top(); Q.pop(); auto g = Q.top(); Q.pop();
        ptr h = new poly(res(*f * *g)); Q.push(h);
    } cout << (*Q.top())[k];
}

改成 18 次 NTT 后的代码:

// 間違ったまま 息をし続け
int lim = 1e9; struct node {
    poly a00, a01, a10, a11, b0, b1;
    il int size() const { return max({a00.size(), a01.size(), a10.size(), a11.size(), b0.size(), b1.size()}); }
    il void DIF(int _n) { ::DIF(_n, a00), ::DIF(_n, a01), ::DIF(_n, a10), ::DIF(_n, a11), ::DIF(_n, b0), ::DIF(_n, b1); }
    il void DIT(int _n) { ::DIT(_n, a00), ::DIT(_n, a01), ::DIT(_n, a10), ::DIT(_n, a11), ::DIT(_n, b0), ::DIT(_n, b1); }
    il void resize(int n) { a00.resize(n), a01.resize(n), a10.resize(n), a11.resize(n), b0.resize(n), b1.resize(n); }
};
il node operator*(node a, node b) {
    int n = a.size() + b.size() - 1, _n = get(n);
    a.DIF(_n), b.DIF(_n); node c; c.resize(_n);
    for (int i = 0; i < _n; i++)
        c.a00[i] = ((ll)a.a00[i] * b.a00[i] + (ll)a.a01[i] * b.a10[i]) % p,
        c.a01[i] = ((ll)a.a00[i] * b.a01[i] + (ll)a.a01[i] * b.a11[i]) % p,
        c.a10[i] = ((ll)a.a10[i] * b.a00[i] + (ll)a.a11[i] * b.a10[i]) % p,
        c.a11[i] = ((ll)a.a10[i] * b.a01[i] + (ll)a.a11[i] * b.a11[i]) % p,
        c.b0[i] = ((ll)a.b0[i] * b.a00[i] + (ll)a.b1[i] * b.a10[i] + b.b0[i]) % p,
        c.b1[i] = ((ll)a.b0[i] * b.a01[i] + (ll)a.b1[i] * b.a11[i] + b.b1[i]) % p;
    c.DIT(_n), c.resize(min(n, lim)); return c;
}