题解:P4484 [BJWC2018] 最长上升子序列

· · 题解

解题思路

问题分析

对于一个长度为

n

的随机排列,总共有

n!

种可能的排列。我们需要计算所有这些排列的 LIS 长度的平均值。

由于

n

最大为 28 ,直接枚举所有排列并计算 LIS 是不现实的( 28! 是一个非常大的数)。因此,我们需要更高效的方法。

算法思想

程序采用了动态规划结合组合数学的方法:

枚举所有可能的 LIS 长度

k

计算具有 LIS 长度为

k

的排列的数量 用加权平均计算期望值:

∑(k×count(k))/n!

核心是通过深度优先搜索 (DFS) 枚举所有可能的整数分拆,这些分拆对应着 LIS 的可能结构,然后计算每个结构对期望的贡献。

程序解析

核心数据结构 使用 mint 结构体处理模

998244353

的运算,包括加减乘除和幂运算 采用分数形式计算,避免浮点数精度误差 主要函数

主函数:

预处理逆元数组 计算

n

(存储在 fac 中) 调用 DFS 函数枚举所有可能的分拆

枚举整数 $n

的所有分拆方式 每个分拆对应一种可能的 LIS 结构

计算当前分拆结构对期望的贡献 通过组合数学计算符合该结构的排列数量 将贡献累加到答案中 ### 关键算法 程序的核心是通过整数分拆来表示 $LIS$ 的可能长度和结构: 对于每个分拆 $seq=[a 1 ​,a 2 ​,...,a k ​]

,表示 LIS 长度为

k

计算这种分拆对应的排列数量,再乘以

k

就是对期望的贡献 所有分拆的贡献之和除以

n!

就是最终期望

代码实现

#include <bits/stdc++.h>
using namespace std;
// #define int long long
#define rep(i, j, k) for(int i = (j); i <= (k); i++)
#define per(i, j, k) for(int i = (j); i >= (k); i--)
#define pb emplace_back
#define fi first
#define se second
using vi = vector<int>;
using pi = pair<int, int>;

template<typename T0, typename T1> bool chmin(T0 &x, const T1 &y){
    if(y < x){x = y; return true;} return false;
}
template<typename T0, typename T1> bool chmax(T0 &x, const T1 &y){
    if(x < y){x = y; return true;} return false;
}

template<typename T> void debug(char *s, T x){
    cerr << s <<" = "<< x <<endl;
}
template<typename T, typename ...Ar> void debug(char *s, T x, Ar... y){
    int dep = 0;
    while(!(*s == ',' && dep == 0)){
        if(*s == '(') dep++;
        if(*s == ')') dep--;
        cerr << *s++;
    }
    cerr <<" = "<< x <<",";
    debug(s + 1, y...);
}
#define gdb(...) debug((char*)#__VA_ARGS__, __VA_ARGS__)

using u32 = uint32_t;
using u64 = uint64_t;
constexpr int mod = 998244353;
struct mint{
    u32 x;

    mint(): x(0){}
    mint(int _x){
        _x %= mod;
        if(_x < 0) _x += mod;
        x = _x;
    }

    u32 val()const {
        return x;
    }
    mint qpow(int y = mod - 2)const {
        assert(y >= 0);
        mint x = *this, res = 1;
        while(y){
            if(y%2) res *= x;
            x *= x;
            y /= 2;
        }
        return res;
    }

    mint& operator += (const mint &B){
        if((x += B.x) >= mod) x -= mod;
        return *this;
    }
    mint& operator -= (const mint &B){
        if((x -= B.x) >= mod) x += mod;
        return *this;
    }
    mint& operator *= (const mint &B){
        x = (u64)x * B.x % mod;
        return *this;
    }
    mint& operator /= (const mint &B){
        return *this *= B.qpow();
    }
    friend mint operator + (const mint &A, const mint &B){
        return mint(A) += B;
    }
    friend mint operator - (const mint &A, const mint &B){
        return mint(A) -= B;
    }
    friend mint operator * (const mint &A, const mint &B){
        return mint(A) *= B;
    }
    friend mint operator / (const mint &A, const mint &B){
        return mint(A) /= B;
    }
    mint operator - ()const {
        return mint() - *this;
    }
};

signed main(){
    #ifdef LOCAL
    freopen(".in", "r", stdin);
    freopen(".out", "w", stdout);
    #endif
    ios::sync_with_stdio(0);
    cin.tie(0);

    int n;
    cin >> n;

    vector<mint> inv(n + 1);
    rep(i, 1, n){
        inv[i] = mint(1) / i;
    }

    mint ans = 0, fac = 1;
    vi seq;
    rep(i, 1, n){
        fac *= i;
    }
    auto slv = [&](){
        mint coef = 1;
        vi sum(n + 1);
        per(i, (int)seq.size() - 1, 0){
            rep(j, 0, seq[i] - 1){
                coef *= inv[seq[i] - j + sum[j]];
                sum[j]++;
            }
        }
        ans += seq[0] * (coef * coef * fac);
    };

    auto dfs = [&](auto &self, int r, int mx){
        if(r == 0){
            slv();
            return;
        }
        chmin(mx, r);
        rep(i, 1, mx){
            seq.pb(i);
            self(self, r - i, i);
            seq.pop_back();
        }
    };

    dfs(dfs, n, n);
    cout << ans.val() <<'\n';
}