题解:AT_abc405_g [ABC405G] Range Shuffle Query

· · 题解

不会 E 但会 G 是什么水平……

拿到题,发现是 q 次询问没有修改,结合 2.5\times 10^5 的不明数据范围,启示我们正解可能是莫队。

首先先想如何统计答案,如果区间出现了 x 个比 X 小的数,那么有 x! 种排列方式,但是这样算会算重。具体来说,对于每个数 i,如果它出现了 cnt_i 次,则它的 cnt_i! 种不同的排列方式都被统计到了答案里。所以正确的答案是 \frac{x}{\prod cnt_i}

那我们就可以莫队了,用树状数组可以实现 O(\log n) 的维护,但是这样时间复杂度是 O(n\sqrt n \log n),是无法通过的。于是我们再仔细想想我们要干什么:

注意到这三个操作中,后两个只会执行 O(n) 次(视 n,q 同阶,以下也如此),但第一个会执行 n\sqrt n 次,如果我们都用 O(\log n) 的时间复杂度来执行,那么看起来有些亏。如果我们能有一个可以 O(1) 实现第一个操作,O(\sqrt n) 实现后两个操作的东西就好了。

这个东西当然是有的,它的名字叫分块。具体来说,我们将值域分为 O(\sqrt n) 块,在块上维护块内元素的和以及积。每次单点修改的时候,我们修改单点,并修改整块的和以及积。查询的时候,我们遍历之前的所有块,累计答案,再遍历最后剩下的一点,累计答案。这样就达到了我们的目标。

于是做完了,代码十分好写,但是有点卡常,如果卡不过去可以试试不要 #define int long long

#include<bits/stdc++.h>
#define N 250005
using namespace std;
const int mod=998244353;
int n,m,c[N],siz,sum,cnt[N],f[N],inv[N],ans[N];
struct yhb{
    int l,r,x,id;
}a[N];
int cmp(yhb x,yhb y){
    if(x.l/siz!=y.l/siz)return x.l<y.l;
    return (x.l/siz)&1?x.r<y.r:x.r>y.r;
}
int id(int x){
    return (x-1)/siz+1;
}
struct bl{
    int c[N],f[N];
    bl(){
        for(int i=1;i<N;++i)c[i]=f[i]=1;
    }
    void add(int x,int k){
        c[x]=1ll*c[x]*k%mod;f[id(x)]=1ll*f[id(x)]*k%mod;
    }
    int ask(int x){
        int lim=id(x),ans=1;
        for(int i=1;i<lim;++i)ans=1ll*ans*f[i]%mod;
        for(int i=(lim-1)*siz+1;i<=x;++i)ans=1ll*ans*c[i]%mod;
        return ans;
    }
}A;
struct bl2{
    int c[N],f[N];
    bl2(){
        for(int i=1;i<N;++i)c[i]=f[i]=0;
    }
    void add(int x,int k){
        c[x]+=k;f[id(x)]+=k; 
    }
    int ask(int x){
        int lim=id(x),ans=0;
        for(int i=1;i<lim;++i)ans+=f[i];
        for(int i=(lim-1)*siz+1;i<=x;++i)ans+=c[i];
        return ans;
    }
}B;
void add(int x){
    B.add(x,1);A.add(x,B.c[x]);
}
void del(int x){
    A.add(x,inv[B.c[x]]);B.add(x,-1);
}
int qpow(int x,int y){
    int ans=1;
    while(y){
        if(y&1)ans=1ll*ans*x%mod;
        x=1ll*x*x%mod;y>>=1;
    }
    return ans;
}
signed main(){
    ios::sync_with_stdio(0);cin.tie(0);
    cin>>n>>m;siz=sqrt(n);
    for(int i=1;i<=n;++i)cin>>c[i];
    f[0]=inv[1]=1;
    for(int i=1;i<=n;++i)f[i]=1ll*f[i-1]*i%mod;
    for(int i=2;i<=n;++i)inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
    for(int i=1;i<=m;++i)cin>>a[i].l>>a[i].r>>a[i].x,a[i].id=i;
    sort(a+1,a+m+1,cmp);
    int l=2,r=1;
    for(int i=1;i<=m;++i){
        while(l>a[i].l)add(c[--l]);
        while(r<a[i].r)add(c[++r]);
        while(l<a[i].l)del(c[l++]);
        while(r>a[i].r)del(c[r--]);
        ans[a[i].id]=1ll*f[B.ask(a[i].x-1)]*qpow(A.ask(a[i].x-1),mod-2)%mod;
    }
    for(int i=1;i<=m;++i)cout<<ans[i]<<'\n';
    return 0;
}