我大概这辈子都忘不了魔裁了

· · 题解

我大概这辈子都忘不了魔裁了。

神题。

将类型为 1 的人视作 a_i 个左括号,类型为 0 的人视为 a_i 个右括号,则问题的本质为括号匹配。

:::info[性质]{open}

v_i=\begin{cases}a_i & \text{type}=1 \\ -a_i & \text{type}=0 \end{cases}sv 的前缀和序列,tv 的后缀和序列。

一个类型为 0 的人 i 能存活,当且仅当 s_is_1,s_2,\dots,s_i 的最小值。一个类型为 1 的人 i 能存活,当且仅当 t_it_i,t_{i+1},\dots,t_n 的最小值。

:::

:::success[证明]{open}

以类型 0 为例。由定义可知,s_i 等于 i 之前未匹配的左括号个数减去未匹配的右括号个数。那么当且仅当 s_i 为前缀最小值时,才会产生无法匹配的右括号,i 才能存活。类型 1 同理。

:::

由上述性质可知,以 s_i 的全局最小值为界,最终存活的人一定是一段前缀类型 0 拼上一段后缀类型 1。考虑对两部分分别 dp,以前缀为例,定义状态 f_{i,j,k} 为:处理到第 i 个人,\sum a_i=j,且当前 s_i 与前缀最小值的差(即还缺的右括号个数)为 k 时的贡献和,则转移分三部分:

后缀同理。做完之后枚举断点合并即可。

直接做是 O(nm^3) 的,不可接受。容易发现转移 1、2 为对角线求和,转移 3 为对一个三角形区域求和,均可用前缀和优化。优化后复杂度 O(nm^2)

:::success[实现]{open}

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int MAXN = 510, MOD = 998244353;
int a[MAXN], p[MAXN], n, m;
int f[2][MAXN][MAXN], g[2][MAXN][MAXN], fa[MAXN][MAXN], ga[MAXN][MAXN];
int suma[MAXN][MAXN], sumb[MAXN][MAXN], sumc[MAXN][MAXN];
void add(int &x, int y){
    x = (x + y) % MOD;
    return;
}
signed main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    cin >> n >> m;
    for (int i = 1; i <= n; i++){
        cin >> a[i];
    }
    for (int i = 1; i <= n; i++){
        cin >> p[i];
    }
    f[0][0][0] = fa[0][0] = 1;
    for (int i = 1; i <= n; i++){
        int curr = i & 1, prev = (i - 1) & 1;
        memset(f[curr], 0, sizeof(f[curr]));
        if (a[i]){
            for (int j = 0; j <= m; j++){
                sumc[j][0] = f[prev][j][0];
                for (int k = 1; k <= j; k++){
                    sumc[j][k] = (sumc[j][k - 1] + f[prev][j][k]) % MOD;
                }
            }
            for (int j = a[i]; j <= m; j++){
                for (int k = 0; k <= j; k++){
                    if (k >= a[i]){
                        add(f[curr][j][k], f[prev][j - a[i]][k - a[i]]);
                    }
                    if (k + a[i] <= m){
                        add(f[curr][j][k], f[prev][j - a[i]][k + a[i]]);
                    }
                }
                int nk = min(j - a[i], a[i] - 1);
                if (nk >= 0){
                    add(f[curr][j][0], sumc[j - a[i]][nk] * p[i]);
                }
            }
        }
        else{
            for (int j = 0; j <= m; j++){
                for (int k = 0; k <= j; k++){
                    suma[j][k] = f[prev][j][k];
                    if (j && k){
                        add(suma[j][k], suma[j - 1][k - 1]);
                    }
                    sumb[j][k] = f[prev][j][k];
                    if (j && k < m){
                        add(sumb[j][k], sumb[j - 1][k + 1]);
                    }
                }
            }
            int tot = 0;
            for (int j = 1; j <= m; j++){
                add(tot, sumb[j - 1][0]);
                for (int k = 0; k <= j; k++){
                    if (k){
                        add(f[curr][j][k], suma[j - 1][k - 1]);
                    }
                    if (k < m){
                        add(f[curr][j][k], sumb[j - 1][k + 1]);
                    }
                }
                add(f[curr][j][0], tot * p[i]);
            }
        }
        for (int j = 0; j <= m; j++) {
            fa[i][j] = f[curr][j][0];
        }
    }
    g[(n + 1) & 1][0][0] = 1;
    for (int i = n; i >= 1; i--){
        int curr = i & 1, prev = (i + 1) & 1;
        memset(g[curr], 0, sizeof(g[curr]));
        if (a[i]){
            for (int j = 0; j <= m; j++){
                sumc[j][0] = g[prev][j][0];
                for (int k = 1; k <= j; k++){
                    sumc[j][k] = (sumc[j][k - 1] + g[prev][j][k]) % MOD;
                }
            }
            for (int j = a[i]; j <= m; j++){
                for (int k = 0; k <= j; k++){
                    if (k >= a[i]){
                        add(g[curr][j][k], g[prev][j - a[i]][k - a[i]]);
                    }
                    if (k + a[i] <= m){
                        add(g[curr][j][k], g[prev][j - a[i]][k + a[i]]);
                    }
                }
                int nk = min(j - a[i], a[i] - 1);
                if (nk >= 0){
                    add(g[curr][j][0], sumc[j - a[i]][nk] * p[i]);
                    ga[i][j] = sumc[j - a[i]][nk] * p[i] % MOD;
                }
            }
        }
        else{
            for (int j = 0; j <= m; j++){
                for (int k = 0; k <= j; k++){
                    suma[j][k] = g[prev][j][k];
                    if (j && k){
                        add(suma[j][k], suma[j - 1][k - 1]);
                    }
                    sumb[j][k] = g[prev][j][k];
                    if (j && k < m){
                        add(sumb[j][k], sumb[j - 1][k + 1]);
                    }
                }
            }
            int tot = 0;
            for (int j = 1; j <= m; j++){
                add(tot, sumb[j - 1][0]);
                for (int k = 0; k <= j; k++){
                    if (k){
                        add(g[curr][j][k], suma[j - 1][k - 1]);
                    }
                    if (k < m){
                        add(g[curr][j][k], sumb[j - 1][k + 1]);
                    }
                }
                add(g[curr][j][0], tot * p[i]);
                ga[i][j] = tot * p[i] % MOD;
            }
        }
    }
    int ans = fa[n][m];
    for (int i = 0; i < n; i++){
        for (int j = 0; j <= m; j++){
            add(ans, fa[i][j] * ga[i + 1][m - j]);
        }
    }
    cout << ans << "\n";
    return 0;
}

:::