QOJ 8351 Ruin the legend

· · 题解

https://qoj.ac/contest/1644/problem/8351

思路

首先我们可以容易发现,对于题目中这样的条件 |a_{p_i} - a_{p_{i+1}}| = k,考虑重新排序,将序列的 xx+k,x-k 链接,来转化为一条条的长链。

注意到由于序列 a 中没有相同元素,每个数 x 最多只与 x−kx+k,相连接,那么序列即由若干条不相交的链(总数 m)和孤立点组成的。

直接统计答案十分困难,我们考虑容斥定理。

我们设 f_{i,j} 表示对于前 i 条链,总共选择 j 条满足 |a_{p_i} - a_{p_{i+1}}| = k 的方案数。因为这个排列还可以进行重排,且我们选择了 j 个位置固定了下来,但是还有剩下的 n-j 个位置,所以最后 f_{m,j}\times (n-j)!

则有 Ans=\sum^{m}_{j=0}(-1)^j \times f_{m,j}\times (n-j)!

这个是可以预处理的,我们令 h_{i,j,0/1} 表示长度为 i 的链中选 j 条,且链的两个端点不相邻/相邻的方案数。

  • 不连接 ii+1

因为不连接 i+1,所以这就是一个新的块,那么就有 h_{i+1,j,0}=h_{i,j,0}+h_{i,j,1}

  • 连接 ii+1

我们加了一条新的 i+1 这个新边,那么 i+1i 的贡献其实是并到了一起,但是注意到我们可以把 i 所在的块“翻转”或“不翻转”,所以我们 h_{i,j,0} 要统计两次。所以更新就有 h_{i+1,j+1,1}=h_{i,j,1}+2h_{i,j,0}

对于第 i 条长度为 l 的链,我们可以从中选择 d 条边。则 f_{i,j+d} 的值由 f_{i-1,j} 转移而来,表示在前 i-1 条链选 j 条边的方案数,乘以在当前链中选 d 条边的方案数 (h_{l,d,0}+h_{l,d,1})

#include<bits/stdc++.h>
#define int long long 
#define rep(i,l,r) for(int i=l;i<=r;++i)
#define per(i,r,l) for(int i=r;i>=l;--i)
using namespace std;
const int N=5010,p=998244353;
int n,k;
vector<int>a,fac,v;
int h[N][N][2];
int f[N][N];
bitset<N>vis;
map<int,int> mp;
void init(){
    a.resize(n+2);
    fac.resize(n+2);
    fac[0]=1;
    rep(i,1,n) fac[i]=fac[i-1]*i%p;
}
signed main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    freopen("count.in","r",stdin);
    freopen("count.out","w",stdout);
    cin>>n>>k; init();
    rep(i,1,n){
        cin>>a[i];
        mp[a[i]]=i;
    }
    h[1][0][0]=1;
    rep(i,1,n){
        rep(j,0,i-1){
            int s0=h[i][j][0];
            int s1=h[i][j][1];
            h[i+1][j][0]=(s0+s1)%p;
            h[i+1][j+1][1]=(s0*2+s1)%p;
        }
    }
    rep(i,1,n){
        if(vis[i])continue;
        int x=a[i],len=0;
        for(;mp.count(x);x+=k){
            vis[mp[x]]=1;
            len++;
        }
        v.push_back(len);
    }
    int i=0;
    f[0][0]=1;
    int m=0;
    for(auto len:v){
        ++i;
        rep(j,0,m){
            rep(d,0,len-1){
                f[i][j+d]=(f[i][j+d]+f[i-1][j]*(h[len][d][0]+h[len][d][1]))%p;
            }
        }
        m+=len-1;
    }
    int ans=0;
    rep(j,0,m){
        int v=(f[i][j]*fac[n-j])%p;
        if(j%2==1){
            ans=(ans-v+p)%p;
        }else{
            ans=(ans+v)%p;
        }
    }
    cout<<ans<<'\n';
    return 0;
}