题解:P16928 聚魔石

· · 题解

好题。

我们将数组按照 s 分为左右两边,一边是 1 \sim s - 1,另一边是 s + 1 \sim n

然后注意到我们 1 \sim s 肯定是 s \rightarrow 1 这个顺序传染的,所以我们设 a_i = s - i

类似的,我们设 b_i = s + i

此时我们设 f_{i, j} 为扩展了 a_i \sim b_j 这些的“共鸣”的期望次数。

我们使用刷表发,看看 f_{i, j} 能带来哪些贡献。

如果 i < nj < n,那么向 f_{i + 1, j} 扩展的概率就是 \frac{w_{a_{i + 1}}}{w_{a_{i + 1}} + w_{b_{j + 1}}},同理也可以计算右边。

那么 f_{i + 1, j} \leftarrow f_{i + 1, j} + f_{i, j} \times \frac{w_{a_{i + 1}}}{w_{a_{i + 1}} + w_{b_{j + 1}}}

同理,f_{i, j + 1} \leftarrow f_{i, j + 1} + f_{i, j} \times \frac{w_{b_{j + 1}}}{w_{a_{i + 1}} + w_{b_{j + 1}}}

如果 i = nj < n,则 f_{i, j + 1} \leftarrow f_{i, j + 1} + f_{i, j}

如果 i < nj = n,则 f_{i + 1, j} \leftarrow f_{i + 1, j} + f_{i, j}

考虑答案统计,在任意一次转移过程中,假设我们从 (x_1, y_1) 转移到 x_2, y_2,转移式子是 f_{x_2, y_2} \leftarrow f_{x_2, y_2} + k \times f_{x_1, y_1}

如果 x_2 = x_1 + 1 并且 p_{a_{x_2}}p_{a_{x_1}} \sim p_{b_{x_2}} 的前缀最大,那么我们将答案加上 k \times f_{x_1, y_1}

同理也可以得到如果 y_2 = y_1 +1 的结果。

最后输出答案加一,原因是第一次也算共鸣。

code

#include <bits/stdc++.h>
using namespace std;

#define ll long long
#define int long long
#define pii pair<int, int>
#define piii pair<pii, int>
#define pll pair<ll, ll>
#define plll pair<pll, ll>
#define fi first
#define se second
const int N = 2e3 + 5, M = 1e6 + 5;
const int inf = 1e9, mod = 998244353;
const ll INF = 1e18;

namespace ARIS0_0{
    int n, s, p[N], w[N], a[N], b[N];
    int mxa[N], mxb[N], dp[N][N];

    int qpow(int a, int b){
        int res = 1;
        a %= mod;
        while (b){
            if (b & 1) res = 1ll * res * a % mod;
            a = 1ll * a * a % mod, b >>= 1;
        }
        return res;
    }

    void init(){
    }
    void solve(){
        cin >> n >> s;
        for (int i = 1; i <= n; i ++ ) cin >> p[i];
        for (int i = 1; i <= n; i ++ ) cin >> w[i], w[i] %= mod;

        int l = s - 1, r = n - s;
        for (int i = 1; i <= l; i ++ ) a[i] = s - i;
        for (int i = 1; i <= r; i ++ ) b[i] = s + i;

        for (int i = 1; i <= l; i ++ ) mxa[i] = max(mxa[i - 1], p[a[i]]);
        for (int i = 1; i <= r; i ++ ) mxb[i] = max(mxb[i - 1], p[b[i]]);

        dp[0][0] = 1;
        int ans = 0, st = p[s];
        for (int i = 0; i <= l; i ++ )
            for (int j = 0; j <= r; j ++ ){
                int cur = dp[i][j] % mod;
                cur = (cur % mod + mod) % mod;
                if (i == l && j == r) continue;

                int mx = st;
                if (i) mx = max(mx, mxa[i]);
                if (j) mx = max(mx, mxb[j]);

                if (i < l && j < r){
                    int xa = a[i + 1], xb = b[j + 1];
                    int inv = qpow((w[xa] % mod + w[xb] % mod) % mod, mod - 2);
                    dp[i + 1][j] = (dp[i + 1][j] + 1ll * cur * (w[xa] % mod) % mod * inv % mod) % mod;
                    dp[i][j + 1] = (dp[i][j + 1] + 1ll * cur * (w[xb] % mod) % mod * inv % mod) % mod;
                    dp[i + 1][j] = (dp[i + 1][j] % mod + mod) % mod;
                    dp[i][j + 1] = (dp[i][j + 1] % mod + mod) % mod;

                    if (p[xa] > mx) ans = (ans + 1ll * cur * (w[xa] % mod) % mod * inv % mod) % mod;
                    if (p[xb] > mx) ans = (ans + 1ll * cur * (w[xb] % mod) % mod * inv % mod) % mod;
                }
                else if (j == r){
                    int xa = a[i + 1];
                    dp[i + 1][j] = (dp[i + 1][j] + cur) % mod;
                    if (p[xa] > mx) ans = (ans + cur) % mod;
                }
                else if (i == l){
                    int xb = b[j + 1];
                    dp[i][j + 1] = (dp[i][j + 1] + cur) % mod;
                    if (p[xb] > mx) ans = (ans + cur) % mod;
                }
                ans = (ans % mod + mod) % mod;
            }

        cout << (ans + 1) % mod << "\n";
    }
    void single(){ init(), solve(); }
    void multi(){ init(); int T; cin >> T; while (T -- ) solve(); }
    void idmulti(){ init(); int id, T; cin >> id >> T; while (T -- ) solve(); }
};

signed main(){
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    ARIS0_0::single();
}