题解:P14617 [2019 KAIST RUN Fall] Hilbert' s Hotel

· · 题解

好题。

先考虑询问 2 怎么做,可以用乘法标记和加法标记维护每个团队的人员编号,新加一个团队时,若 k=0,则相当于给之前所有团队打上一个 \times 2 的乘标记,否则相当于打上一个 +k 的加标记,这部分可以用线段树维护,详见 P3373。查询时找出该团队的乘法标记 mul,和加法标记 add,若该团队 k=0,位置即为 (2x-1)mul+add,否则即为 (x-1)mul+add

再考虑询问 3,发现这个用数据结构不太好维护,于是考虑挖掘性质,假设我们暴力的话,每遇到一个 k=0 的团队,若 x 的位置为奇数,说明答案为当前团队,否则就将 x 除以二,继续往前查询。如果每次只跳到 k=0 的团队时,可以在 \log x 次就跳到 0,那么具体的,维护一个 k 的前缀和,上一个 k=0 的团队 las,设当前跳到团队 nw,如果 sum_{nw}-sum_k>x,就说明答案在这里面,一个 lower_bound 即可解决,否则将 x 减掉上式,继续往前查询。此外,当 x=0 时需要跳到最近的一个 k>0 的团队,否则会假。特别注意和团队 0 有关的 corner case。综上,时间复杂度为 O(n\log V)

赛时代码:(没想到居然还是首 A)

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=3e5+5,mod=1e9+7;
struct node{
    int mul,add;
}t[N<<2];
struct node0{
    int x,las,hav;
}s[N];
int sum[N];
void pd(int u){
    t[u<<1].mul=t[u<<1].mul*t[u].mul%mod;
    t[u<<1].add=(t[u<<1].add*t[u].mul+t[u].add)%mod;
    t[u<<1|1].mul=t[u<<1|1].mul*t[u].mul%mod;
    t[u<<1|1].add=(t[u<<1|1].add*t[u].mul+t[u].add)%mod;
    t[u].mul=1,t[u].add=0;
    return;
} 
int ask(int u,int l,int r,int x,int y){
    if(l==r){
        if(s[x].x==0) return (2*y-1)*t[u].mul+t[u].add;
        return (y-1)*t[u].mul+t[u].add;
    }
    pd(u);
    int mid=l+r>>1;
    if(x<=mid) return ask(u<<1,l,mid,x,y);
    return ask(u<<1|1,mid+1,r,x,y);
}
void upd(int u,int l,int r,int x,int y,int mul,int add){
    if(x<=l && r<=y){
        t[u].mul=t[u].mul*mul%mod;
        t[u].add=(t[u].add*mul+add)%mod;
        return;
    }
    pd(u);
    int mid=l+r>>1;
    if(x<=mid) upd(u<<1,l,mid,x,y,mul,add);
    if(mid<y) upd(u<<1|1,mid+1,r,x,y,mul,add);
    return;
}
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0);
    int Q,n=3e5,g=0,las=0,hv=0;
    for(int i=1;i<=(n<<2);i++) t[i].mul=1;
    s[0].x=1;
    cin>>Q;
    while(Q--){
        int op;
        cin>>op;
        if(op==1){
            int x;
            cin>>x;
            g++;
            s[g].x=x;
            s[g].las=las;
            s[g].hav=hv;
            if(x==0) upd(1,0,n,0,g-1,2,0),las=g;
            else upd(1,0,n,0,g-1,1,x),hv=g;
            sum[g]=sum[g-1]+x;
        }
        if(op==2){
            int num,x;
            cin>>num>>x;
            cout<<ask(1,0,n,num,x)%mod<<"\n";
        }
        if(op==3){
            int x;
            cin>>x;
            int nw=g,ans=0;
            if(s[nw].x==0){
                if(x%2==1) ans=nw;
                x/=2;
            }
            else if(s[nw].x>x) ans=nw;
            while(ans==0 && nw){
                if(x==0){
                    ans=s[nw].hav;
                    break;
                }
                if(sum[nw]-sum[s[nw].las]>x){
                    ans=lower_bound(sum,sum+nw+1,sum[nw]-x)-sum;
                    break;
                }
                x-=sum[nw]-sum[s[nw].las];
                nw=s[nw].las;
                if(x%2==1){
                    ans=nw;
                    break;
                }
                x/=2;
            }
            cout<<ans<<"\n";
        } 
    }
    return 0;
}