题解:CF2045H Missing Separators

· · 题解

考虑说直接贪心不太能做,所以考虑 DP。

考虑说我们选后缀比前缀容易,所以我们倒着 DP。

如果设 $f_{i,j}$ 表示把后缀 $[i,n]$ 分成 $j$ 段是否可行,那么 $f_{i,j}$ 需要从 $f_{k,j-1}$ 转移;但这样除了要枚举 $i,j,k$,还需要知道从 $k$ 开始的那一段的长度(才能比较相邻两段的大小),感觉不太好优化。

考虑到原来的 DP 主要关心最后一段多长,所以我们考虑设 $f_{i,j}$ 表示划分后缀 $[i,n]$、且其中最靠前的一段(也就是倒着 DP 时最后确定的那一段)为 $[i,j]$ 时,最多能分成多少段。

然后我们枚举 $i, j$,用 SA 求出后缀 $[i,n]$ 和 $[j+1,n]$ 的最长公共前缀 $lcp$,并令 $t=\min(lcp,\ j-i+1)$。如果 $t=n-j$,说明剩下的后缀 $[j+1,n]$ 整个是 $[i,j]$ 的前缀,后面的段不可能严格大于 $[i,j]$,此时 $f_{i,j}$ 无解。否则,如果 $t=j-i+1$(即 $[i,j]$ 是后缀 $[j+1,n]$ 的真前缀),或者 $s_{i+t}<s_{j+1+t}$,那么下一段 $[j+1,k]$ 只要满足 $k\ge j+1+t$ 就严格大于 $[i,j]$,于是 $f_{i,j}=1+\max_{k=j+1+t}^{n} f_{j+1,k}$。这是一个后缀 $\max$,于是再开一个数组维护即可。

#include <bits/stdc++.h>

using namespace std;

// Capacity of the suffix-array work arrays below. It is about twice the size
// of the string buffer; the doubling step reads ranks at index sa[i] + w.
const int N = 10010;

// n: length of the input string.
// m: number of counting-sort buckets used in the current doubling round.
// p: number of distinct ranks produced by the latest doubling round.
int n, m, p;
// Input string, stored 1-indexed (it is read into s + 1).
char s[5010];
// f[i][j] = {word count, end index of the following word} for splitting the
// suffix s[i..n] so that its first word is s[i..j]. The second component is
// n + 1 when s[i..j] is the only word; an entry left at {0, 0} means that no
// split was recorded for that (i, j).
pair<int, int> f[5010][5010];
// g[i][j] = max over k in [j, n] of {f[i][k].first, k}: a suffix maximum of
// row i of f that also remembers which end index achieved it.
pair<int, int> g[5010][5010];
// st[lvl][r]: sparse table of range minima; st[0][r] holds the common-prefix
// length of the suffixes ranked r - 1 and r. log_2[x] = floor(log2(x)).
int st[16][N], log_2[N];
// sa[r]: start of the suffix with rank r; rk[i]: rank of the suffix starting
// at i; ork: copy of rk from the previous round; id: suffixes ordered by
// their second sort key; cnt: counting-sort buckets.
int sa[N], rk[N], ork[N], id[N], cnt[N];

// Range-minimum query over st[0][l..r] using the sparse table.
// The range is covered by two (possibly overlapping) power-of-two blocks,
// one anchored at l and one ending at r. Requires l <= r.
int queryM(int l, int r) {
    const int level = log_2[r - l + 1];
    const int left_block = st[level][l];
    const int right_block = st[level][r - (1 << level) + 1];
    return right_block < left_block ? right_block : left_block;
}

int main() {
    scanf("%s", s + 1);
    n = strlen(s + 1);

    m = 255;
    for (int i = 1; i <= n; ++i) cnt[rk[i] = s[i]]++;
    for (int i = 1; i <= m; ++i) cnt[i] += cnt[i - 1];
    for (int i = n; i >= 1; --i) sa[cnt[rk[i]]--] = i;

    for (int w = 1; ; w <<= 1, m = p) {
        int cc = 0;
        for (int i = n - w + 1; i <= n; ++i) id[++cc] = i;
        for (int i = 1; i <= n; ++i)
            if (sa[i] > w) id[++cc] = sa[i] - w;

        memset(cnt, 0, sizeof(cnt));
        for (int i = 1; i <= n; ++i) cnt[rk[i]]++;
        for (int i = 1; i <= m; ++i) cnt[i] += cnt[i - 1];
        for (int i = n; i >= 1; --i) sa[cnt[rk[id[i]]]--] = id[i];

        memcpy(ork, rk, sizeof(ork));
        p = 0;
        for (int i = 1; i <= n; ++i)
            if (ork[sa[i]] == ork[sa[i - 1]] && ork[sa[i] + w] == ork[sa[i - 1] + w])
                rk[sa[i]] = p;
            else
                rk[sa[i]] = ++p;

        if (p == n) break;
    }

    for (int i = 1, k = 1; i <= n; ++i) {
        if (!rk[i]) continue;
        if (k) --k;
        // cout << sa[rk[i] - 1] << ' ' << i << ' ' << rk[i] << ' ' << sa[rk[i]] << ' ' << sa[rk[i] - 1] << ' ' << sa[i] << endl;
        while (s[i + k] == s[sa[rk[i] - 1] + k]) ++k;

        st[0][rk[i]] = k;
    }

    for (int i = 1; i <= 15; ++i)
        for (int j = 1; j <= n; ++j)
            st[i][j] = min(st[i - 1][j], st[i - 1][j + (1 << i - 1)]);
    log_2[1] = 0;
    for (int i = 2; i <= n; ++i) log_2[i] = log_2[i >> 1] + 1;

    for (int i = n; i >= 1; --i) {
        for (int j = i; j <= n; ++j) {
            if (j == n) {
                f[i][j] = {1, n + 1};
                break;
            }

            int l = rk[i], r = rk[j + 1];
            if (l > r) swap(l, r);
            int lcp = min(queryM(l + 1, r), j - i + 1);

            if (lcp == n - j) continue;

            if (lcp >= j - i + 1) {
                if (g[j + 1][j + j - i + 2].first + 1 > f[i][j].first)
                    f[i][j] = {g[j + 1][j + j - i + 2].first + 1, g[j + 1][j + j - i + 2].second};
            } else {
                if (s[i + lcp] < s[j + lcp + 1]) {
                    if (g[j + 1][j + lcp + 1].first + 1 > f[i][j].first)
                        f[i][j] = {g[j + 1][j + lcp + 1].first + 1, g[j + 1][j + lcp + 1].second};
                }
            }

            // cout << i << ' ' << j << ' ' << lcp << ' ' << f[i][j].first << "----" << ' ' << j + lcp + 1 << ' ' << g[j + 1][j + lcp + 1].first << endl;
        }

        for (int j = n; j >= i; --j) {
            g[i][j] = max(g[i][j + 1], {f[i][j].first, j});
        }
    }

    int ans = 0, id = 0;
    for (int i = 1; i <= n; ++i) {
        if (f[1][i].first > ans) {
            ans = f[1][i].first;
            id = i;
        }
    }

    printf("%d\n", ans);
    int l = 1;
    while (id <= n) {
        for (int i = l; i <= id; ++i) putchar(s[i]);
        putchar('\n');
        int tid = id;
        id = f[l][id].second;
        l = tid + 1;
    }

    return 0;
}