题解:P13757 【MX-X17-T6】Selection

· · 题解

先钦定前 k 个数组大于后 n-k 个数组,最后乘上 \binom{n}{k}

钦定前 k 个数组每个位置的最小值,后 n-k 个要求小于等于这个最小值数组,且由于要严格小于,不能有前 k 个、后 n-k 个都有等于这个最小值数组的。

那么要求的问题转化为,总方案数,减去“都有等于”的方案数。

总方案数就是拆成每个位置,枚举最小值,为:

(\sum_{i=1}^v ((v-i+1)^k - (v-i)^k) i^{n-k})^m

这个式子是一个关于 vn+1 次多项式,不难插值求出。

要减去的部分,枚举前 k 个、后 n-k 个等于的个数,分别为 i,j。同样拆成每个位置贡献独立,可得:

\sum_{i=1}^k\sum_{j=1}^{n-k}(-1)^{i+j}\binom{k}{i}\binom{n-k}{j}(\sum_{z=1}^{v} z^{k-i} (v-z+1)^{n-k-j})^m

现在问题在于对于每个 i,j 求出 \sum_{z=1}^{v} z^{i} (v-z+1)^{j}

设二元 EGF:

F =\sum_{i,j} \frac{\sum_{z=1}^{v} z^{i} (v-z+1)^{j}}{i!j!} x^i y^j

对两维求导,可得:

\frac{dF}{dx} =\sum_{i,j} \frac{\sum_{z=1}^{v} z^{i+1} (v-z+1)^{j}}{i!j!} x^i y^j \frac{dF}{dy} =\sum_{i,j} \frac{\sum_{z=1}^{v} z^{i} (v-z+1)^{j+1}}{i!j!} x^i y^{j} \frac{dF}{dx} + \frac{dF}{dy} = (v+1) F

提取两边 [x^iy^j] 系数:

(v+1)f_{i,j} = (i+1)f_{i+1,j} + (j+1)f_{i,j+1}

据此递推即可 O(n^2) 求出所有 f_{i,j}。初值需要用到 f_{i,0} = \sum_{z=1}^{v}z^i,这是自然数幂和,可以线性插值。

总复杂度 O(n^2\log m)

struct fval{
    vector<modint>x,y;
    void ins(modint X,modint Y){x.pb(X),y.pb(Y);}
    void init(){x.clear(),y.clear();}
    modint val(modint k){
        int n=x.size();
        modint res=0;
        For(i,0,n-1){
            modint s1=1,s2=1;
            For(j,0,n-1)if(i!=j)s1*=(k-x[j]),s2*=(x[i]-x[j]);
            res+=y[i]*s1/s2;
        }
        return res;
    }
}F;

int n,m,k,v;
modint f[4005][4005],sum[8005];

namespace CF{
    // O(k) 
    modint ml[maxn],mr[maxn];
    modint solve(int n,int k){
        if(n<=k+5){
        modint zz=0;
        For(i,1,n)zz+=qpow(i,k);return zz;}
        For(i,0,k+4)ml[i]=mr[i]=0;
        ml[0]=mr[k+3]=1;
        For(i,1,k+2) ml[i]=ml[i-1]*(n-i);
        Rep(i,k+2,1)mr[i]=mr[i+1]*(n-i);
        modint res=0,y=0;
        For(i,1,k+2){
            y+=modint(i)^k;
            modint a=ml[i-1]*mr[i+1];
            modint b=ifac[i-1]*ifac[k+2-i];
            if((k-i)&1)b=-b;
            res+=y*a*b;
        }
        return res;
    }
}

void work(int O)
{
    n=read(),m=read(),k=read(),v=read();
    modint res=0;

    modint all=0;

    F.init();
    For(vv,1,n+3){
        modint sv=0;
        For(i,1,vv){
            modint tmp=qpow(vv-i+1,k);
            tmp*=(qpow(i,n-k)-qpow(i-1,n-k));
            sv+=tmp;
        }
    //  cout<<"vv,sv "<<vv<<" "<<sv.x<<"\n";
        F.ins(vv,sv);
    }

    all=F.val(v);

//  cout<<"all "<<v<<' '<<all.x<<"\n";
//  For(i,1,v){
//      modint tmp=qpow(v-i+1,k);
//      tmp*=(qpow(i,n-k)-qpow(i-1,n-k));
//      all+=tmp;
//  }

    res=qpow(all,m);
//  cout<<"res "<<res.x<<"\n";

    // need: for i,j, f[i][j]=\sum qpow(v-z,i)*qpow(z,j)
    // (v+1)*f[i][j] = f[i+1][j]*(i+1) + f[i][j+1]*(j+1)
    // f[i][j+1] * (j+1) = 
    auto calc=[&](int i,int j){
//      modint tmp=0;
//      For(z,1,v) tmp+=qpow(v-z+1,i)*qpow(z,j);
        modint tmp=0;
        For(z,1,v) tmp+=qpow(z,i);
        return tmp;
    };

    For(i,0,n){
        f[i][0]=CF::solve(v,i);
        //f[i][0]=calc(i,0);
        if(i==0) f[i][0]=v;
        f[i][0]*=ifac[i];
//      F.init();
//      For(j,1,i+1){
//          sum[j]=sum[j-1]+qpow(j,i);
//      }
//      For(j,0,i+1) F.ins(j,sum[j]);
//      f[i][0]=F.val(v)*ifac[i];
    //  f[i][0]=calc(i,0)*ifac[i];
    //  cout<<"i,0 "<<i<<" "<<0<<" "<<f[i][0].x<<"\n";
    }

    For(j,0,n-1){
        For(i,0,n){
            // 
            f[i][j+1]=(f[i][j]*(v+1)-f[i+1][j]*(i+1));
            f[i][j+1]*=iv[j+1];
    //      cout<<"i,j "<<i<<" "<<j+1<<" "<<f[i][j+1].x<<"\n";
        }
    }

    For(i,1,k) For(j,1,n-k) {
        modint tmp=0;
        tmp=f[k-i][n-k-j];
        tmp*=fac[k-i]*fac[n-k-j];
    //  For(z,1,v) tmp+=qpow(v-z+1,k-i)*qpow(z,n-k-j);
    //  cout<<"tmp "<<tmp.x<<"\n";
        tmp=qpow(tmp,m);
        tmp*=C(k,i)*C(n-k,j);
        res-=tmp*sign(i+j);
    }
    res*=C(n,k);
    cout<<res.x<<"\n";
}

signed main()
{
    initC(5005);
    int T=read();
    For(_,1,T)work(_);
    return 0;
}