题解:P14846 [ICPC 2022 Yokohama R] Incredibly Cute Penguin Chicks/ ARIS0_0 - 15

· · 题解

各种神秘 DP 写累了,来写一道普通点的 BIT 优化 DP。

题意

给定一个仅由三种字符构成的字符串,问有多少种方法,可将其划分成若干个连续区间,满足区间内存在两种字符的数目相同,且另一个字符的数目是最大的。

思路

我们维护一个前缀和 pre_{i,j} 表示前缀区间 [1,i] 内第 j 种字符的出现次数。那么对于一个划分的合法区间 [l,r] 就需要满足:

pre_{r,x}-pre_{l-1,x}=pre_{r,y}-pre_{l-1,y}\\pre_{r,x}-pre_{l-1,x}<pre_{r,z}-pre_{l-1,z}

其中 x,y,z 互不相同。

于是就可以搓出一个 O(n^2) 的 DP 了,设 dp_i 表示前缀 [1,i] 有多少种划分方法,那么对于 dp_i 的转移就是枚举每个 j,若 [j+1,i] 是合法的,就加上 dp_{j} 的贡献。这么做显然是会超时的,需要优化。

我们摆弄一下合法区间需要满足的等式与不等式,就可以得到:

pre_{r,x}-pre_{r,y}=pre_{l-1,x}-pre_{l-1,y}\\pre_{r,x}-pre_{r,z}<pre_{l-1,x}-pre_{l-1,z}

不难发现这是一个二维偏序,我们可以对等式分组,将所有 pre_{i,x}-pre_{i,y} 的值相同的 i 分为一组,用 vector 存下它们 pre_{i,x}-pre_{i,z} 的值并离散化。

注意到对于每个 dp_i,能对其做贡献的 dp_j 需要与它是同一组且 pre_{j,x}-pre_{j,z} 要比 i 的大。因此转移本质上是在某组值域上求后缀和,于是就可以对每组都建 BIT 来维护,BIT 的总长是 O(n) 的。

注意到 x,y,z 的排列是有三种情况的,因此每组需针对每种情况建立三个 BIT。特别地,如果前缀区间 [1,i] 合法也可以产生一个 1 的贡献。

这么做的时间复杂度是 O(n\log n),可以通过。

代码

比较丑,感性理解吧……

#include<bits/stdc++.h>
#define cin_fast ios::sync_with_stdio(false) , cin.tie(0) , cout.tie(0)
#define fi first
#define se second
//#define int long long 
#define in(a) a = read()
#define rep(i , a , b) for(int i = a ; i <= b ; i ++)
using namespace std;
typedef long long ll;
const int N = 1e6 + 5 , mod = 998244353;
const int inf = 0x3f3f3f3f;
const long long INF = 0x3f3f3f3f3f3f3f3f; 
inline int read() {
    int x = 0;
    char ch = getchar();
    bool f = 0;
    while('9' < ch || ch < '0') f |= ch == '-' , ch = getchar();
    while('0' <= ch && ch <= '9') x = (x << 3) + (x << 1) + ch - '0' , ch = getchar();
    return f ? -x : x;
}
string s;
ll dp[N];
int c(char x) {
    if(x == 'C') return 0;
    else if(x == 'I') return 1;
    else return 2;
}
int pre[N][3] , b[N] , tot , n;
vector<ll>v[N][3] , BIT[N][3];
void add(int o , int p , int x , int val) {
    for(int i = x ; i <= v[o][p].size() ; i += i & -i) BIT[o][p][i] += val , BIT[o][p][i] %= mod;
}
ll sum(int o , int p , int x) {
    ll h = 0;
    for(int i = x ; i ; i -= i & -i) h += BIT[o][p][i] , h %= mod;
    return h;
} 
int id[N][3];
void solve(int x , int y , int z) {
    tot = 0;
    for(int i = 1 ; i <= n ; i ++) b[++ tot] = pre[i][x] - pre[i][y];
    sort(b + 1 , b + tot + 1) , tot = unique(b + 1 , b + tot + 1) - b - 1;
    for(int i = 1 ; i <= n ; i ++) {
        id[i][z] = lower_bound(b + 1 , b + tot + 1 , pre[i][x] - pre[i][y]) - b;
        v[id[i][z]][z].push_back(pre[i][x] - pre[i][z]);
    }
    for(int i = 1 ; i <= tot ; i ++) {
        sort(v[i][z].begin() , v[i][z].end());
        unique(v[i][z].begin() , v[i][z].end()) - v[i][z].begin();
        BIT[i][z].resize(v[i][z].size() + 5);
    }
}
signed main() {
    //cin_fast;
    cin >> s;
    n = s.size();
    s = '#' + s;
    for(int i = 1 ; i <= n ; i ++) {
        pre[i][0] = pre[i - 1][0] , pre[i][1] = pre[i - 1][1] , pre[i][2] = pre[i - 1][2];
        pre[i][c(s[i])] ++;
    }
    solve(0 , 1 , 2) , solve(2 , 0 , 1) , solve(1 , 2 , 0);
    for(int i = 1 ; i <= n ; i ++) {
        for(int j = 0 ; j < 3 ; j ++) {
            int now = lower_bound(v[id[i][j]][j].begin() , v[id[i][j]][j].end() , pre[i][(j + 1) % 3] - pre[i][j]) - v[id[i][j]][j].begin() + 1;
            dp[i] += (sum(id[i][j] , j , v[id[i][j]][j].size()) - sum(id[i][j] , j , now) + mod) % mod;
            if(pre[i][(j + 1) % 3] == pre[i][(j + 2) % 3] && pre[i][(j + 1) % 3] < pre[i][j]) dp[i] ++;
            dp[i] %= mod;
        }
        for(int j = 0 ; j < 3 ; j ++) {
            int now = lower_bound(v[id[i][j]][j].begin() , v[id[i][j]][j].end() , pre[i][(j + 1) % 3] - pre[i][j]) - v[id[i][j]][j].begin() + 1;
            add(id[i][j] , j , now , dp[i]);   
        }
    }
    cout << dp[n];
    return 0;
}
/*
我听见
飘走的流年任风吹
冬至的白雪在纷飞
忙忙碌碌又一年
记忆的碎片又积成堆
风雪的消散要多快
才能将光阴拦下
也许
这银装素裹的世界
隔天就会融化
this is ARIS 15
*/