题解 P3477 【[POI2008]PER-Permutation】

· · 题解

给出一个长度为 n 的序列 {S},

现将这个序列的所有不同 排列按字典序排序,求 {S} 的排名模 m 的值。

也就是求又多少个排列{A}字典序<{S}。

考虑每一位的贡献。

对于第i位,我们要求{A}的前i-1位和{S}相同,而且A[i]<S[i]

枚举A[i]=x,x是<S[i]的所有能选的数

贡献=(n-i)!/(a1!*a2!*..an!)*ax

其中a1,a2..an是i..n中n个不同的数的各自的出现次数

从后往前考虑i,记录每个数出现次数,前半部分是容易维护的。

对于x这个部分,显然x的和就是<S[i]的所有的数的个数。

这可以用树状数组维护。

注意m不一定是质数,除法就不能用逆元转成乘法了,

所以要先把m分解质因数,算出每个mod pi^ai下的答案,再用CRT合并。

注意这里可能有数是pi的倍数,必须将每个数表示成a*pi^b的形式,否则就可能变成0。

时间O(logm*nlogn)

(竟然t了一个点..)

#include<bits/stdc++.h>

const int N=300100,U=300000;  
#define ll long long
int n,m,a[N];
int base[40],st[40],top;

void exgcd(int a,int b,int &x,int &y)//solve ax+by=1
{
    if(!b) {x=1;y=0;return ;}
    exgcd(b,a%b,y,x);
    y-=a/b*x;
}
int niv(int x,int p)//x^-1 %p
{
    int a,b;
    exgcd(x,p,a,b);
    return (a%p+p)%p;
}

int c[N];
int cnt[N];
void add(int i)
{
    for(;i<=U;i+=i&-i) ++c[i];
}
int qiu(int i)
{
    int ans=0;
    for(;i;i-=i&-i) ans=ans+c[i];
    return ans;
}

int p,num;

namespace kcz
{
    int a,b;
    void zhuan(int x)//x->a*p^b
    {
        b=0;
        while(!(x%p)) {x/=p;++b;}
        a=x;
    }
}

int p_mi[N];//p^i%m
int solve(int m)
{
    int i;
    p_mi[0]=1;
    for(i=1;i<=U;++i) p_mi[i]=p_mi[i-1]*p%m;
    for(i=1;i<=U;++i) c[i]=cnt[i]=0;
    int ans=1,ax=1;num=0;//ax=(n-i)!/(a1!*..an!) ax实际上=ax*p^num 
    cnt[a[n]]=1;add(a[n]);

    for(i=n-1;i;--i)
    {
        kcz::zhuan(n-i);
        num+=kcz::b;
        ax=(ll)ax*kcz::a%m;

        kcz::zhuan(++cnt[a[i]]);
        num-=kcz::b;
        ax=(ll)ax*niv(kcz::a,m)%m;
        add(a[i]);

        ans=(ans+(ll)ax*qiu(a[i]-1)%m*p_mi[num])%m; 
    }
    return ans;
}

int main()
{ freopen("1.in","r",stdin);freopen("1.out","w",stdout);
    scanf("%d%d",&n,&m);
    int i;
    for(i=1;i<=n;++i) scanf("%d",a+i);

    int x=m;
    for(i=2;i*i<=x;++i)
    if(!(x%i))
    {
        base[++top]=i;st[top]=1;
        while(!(x%i)){x/=i;st[top]*=i;} 
    }
    if(x>1){++top;base[top]=st[top]=x;}

    int ans=0;
    for(i=1;x=st[i];++i) 
    {
        p=base[i];
     int a=solve(x);
     int M=m/x;
     ans=(ans+(ll)a*M%m*niv(M,x))%m;
    }

    printf("%d\n",ans);
}