题解:P9307 「DTOI-5」进行一个排的重 (Maximum Version)

· · 题解

二维前缀和喜欢取模被卡飞了。

思路

先考虑最大值怎么求。发现在最优情况中,每个 a_i 都至少产生 1 的贡献。这是由于将未产生贡献的 a_i 提前是不劣的。那么可以直接以 p_i 为关键字排序,则答案为 n 加上 q 序列的最长上升子序列长度。同时我们也可以由此得到任意答案一定会包含此时 q 的任意一个最长上升子序列。接下来的讨论基于此时的 q

考虑如何求方案数。设 f_i 为以 q_i 为结尾的最长上升子序列长度,g_i 表示以 i 为开头的后缀中,i 在一个最长上升子序列的方案数。那么若对于 j>i,其 f_j 可以由 f_i 贡献得到,那么可以将满足 1\le k<i,q_i<q_k<q_j 的元素按顺序放入,再把满足 i<k<j,q_k<q_i 的元素插入。用二维前缀和统计一下这两种元素的个数,然后直接组合数计算转移就行。

复杂度 O(n^2)

代码

#include<bits/stdc++.h>
#define ll long long
#define ull usigned long long
#define ld long double
#define stg string
using namespace std;
int read() {
    int r=0,flag=1;char c = getchar();
    while(c<'0'||c>'9') {if(c=='-') flag=-1;c=getchar();}
    while(c<='9'&&c>='0') {r=(r<<1)+(r<<3)+(c-'0');c=getchar();}
    return r*flag;
}
const int N=1e4+7;
const ll mod=998244353;
ll fac[N+5],inf[N+5],inv[N+5];
void init() {
    inv[1]=fac[0]=inf[0]=1;
    for(int i=2;i<=N;i++) inv[i]=((mod-mod/i)*inv[mod%i])%mod;
    for(int i=1;i<=N;i++) fac[i]=(fac[i-1]*i)%mod,inf[i]=(inf[i-1]*inv[i])%mod;
}
inline ll C(int n,int m) {
    if(n<m||n<0||m<0) return 0;
    return ((fac[n]*inf[m])%mod*inf[n-m])%mod;
}
ll ksm(int a,int b) {
    ll res=1;
    while(b) {if(b&1) res=(res*a)%mod;a=(a*a)%mod;b>>=1;}
    return res;
}
ll dp[N],sm[N][N],n,ans[N];
inline ll gsm(int x1,int y1,int x2,int y2) {
    //cout<<"#"<<x1<<" "<<y1<<" "<<x2<<" "<<y2<<endl;
    if(x1<=0||y1<=0||x2<=0||y2<=0) return 0;
    return sm[x2][y2]+sm[x1-1][y1-1]-sm[x1-1][y2]-sm[x2][y1-1];
}
struct node {
    int p,q;
    bool operator <(const node& x) const {
        if(p==x.p) return q<x.q;
        else return p<x.p;
    }
} a[N];
int main() {
    init();
    cin>>n;
    for(int i=1;i<=n;i++) a[i].p=read();
    for(int i=1;i<=n;i++) a[i].q=read();
    sort(a+1,a+n+1);
    a[n+1].q=n+1;
    for(int i=1;i<=n+1;i++) {
        ans[i]=1;
        for(int j=1;j<i;j++) if(a[i].q>a[j].q) ans[i]=max(ans[i],ans[j]+1);
    }
    for(int i=1;i<=n;i++) {
        for(int j=1;j<=n;j++) sm[i][j]=sm[i][j-1]+sm[i-1][j]-sm[i-1][j-1]+(a[i].q==j);
    }
    dp[n+1]=1;
    for(int i=n;i>=0;i--) {
        for(int j=i+1;j<=n+1;j++) {
            if(a[j].q>a[i].q&&ans[j]==ans[i]+1) {
                int cnt=gsm(1,a[i].q+1,i-1,a[j].q-1);
                dp[i]=(dp[i]+((dp[j]*C(cnt+gsm(i+1,1,j-1,a[i].q),cnt))%mod))%mod;
            //cout<<i<<" "<<j<<" "<<t<<" "<<a[i].q<<endl;
            }
        }
    }
    cout<<ans[n+1]+n-1<<" "<<dp[0];
}