题解:AT_abc434_g [ABC434G] Keyboard

· · 题解

更好的阅读体验

Update on 2025/12/04:修改了一处笔误。

使用了一个较为优美的写法,较短而且可能不容易写错。

单侧递归线段树!

首先,如果一个字符串里含有形如 x\texttt{B} 的结构,其中 x 是一个数字,那么这个 x 一定会被删掉。假设 S 删完会变成 f(S)。那么由于不存在形如 x\texttt{B} 的结构了,所以一定是前面若干个 \texttt{B},后面一堆数字。

我们希望用线段树维护 f(S)。肯定不能维护整个字符串,我们尝试把前面的一串 \texttt{B} 和后面的一堆数字分别维护起来。我们需要前面 \texttt{B} 的个数 b,后面数字串的长度 l,后面数字串的值 x

然后我们会发现一个严峻的问题,就是在合并两个区间的时候,后面的 \texttt{B} 会把前面的数字删掉一些。

做过楼房重建就对单侧递归这个手法不陌生了。这个信息其实可以在 O(\log n) 的时间内完成合并。

我们在合并的时候,我们需要知道左子树删掉右子树的 b 之后剩下什么数。我们可以假设 g(p, k) 表示 p 子树的最后 k 个数字是什么,在合并的时候把左子树扣掉一个 g 再并到右子树上就可以。

这时就很简单了,求 g 的时候如果右子树的数字个数 \ge k 那么就递归右子树,否则将 k 减掉的数字个数然后去左子树找,然后再并上右子树剩下的数字。这样求 g 可以做到 O(\log n)

区间查询同样也可以这么做。我们先把要查询的区间对应的线段树节点抠出来,然后由于在上述合并过程中,我们需要左孩子的节点编号,所以我们从右往左合并,就可以了。

单点修改是简单的。

那么这道题就做完了,复杂度 O(n \log n + q \log^2 n),前半部分是建树复杂度。

#include<bits/stdc++.h>
#define int long long
#define endl '\n'
#define N 8000006
#define MOD 998244353
using namespace std;
int n,q,pw[N],ipw[N]; char ch[N];
inline int qpow(int x,int y=MOD-2)
{
    int ret=1;
    for(;y;y>>=1,x=x*x%MOD)if(y&1)ret=ret*x%MOD;
    return ret;
}
struct Node {int b,l,x;};
Node node(int x,int y,int z){return {x,y,z};}
struct Segtree {
    Node tree[N<<2];
    //b: number of prefix 'B'
    //l: length of numbers
    //x: value of numbers
    int dfs(int p,int k)
    {
        if(!k)return 0;
        if(tree[p].l==k)return tree[p].x;
        if(tree[p<<1|1].l>=k)return dfs(p<<1|1,k);
        int m=k-tree[p<<1|1].l,b=tree[p<<1|1].b,y=dfs(p<<1,m+b);
        int z=(tree[p<<1].x-(tree[p].x-tree[p<<1|1].x)*ipw[tree[p<<1|1].l]%MOD*pw[b]%MOD+MOD)%MOD;
        return ((y-z+MOD)*ipw[b]%MOD*pw[tree[p<<1|1].l]+tree[p<<1|1].x)%MOD;
    }
    void push_up(Node &p,int ls,Node rs)
    {
        p.b=tree[ls].b+max(0ll,rs.b-tree[ls].l);
        p.l=max(0ll,tree[ls].l-rs.b)+rs.l;
        int num=tree[ls].l>rs.b?(tree[ls].x-dfs(ls,rs.b)+MOD)*ipw[rs.b]%MOD:0;
        p.x=(num*pw[rs.l]%MOD+rs.x)%MOD;
    }
    void build(int p,int l,int r)
    {
        if(l==r)return tree[p]=(ch[l]=='B'?node(1,0,0):node(0,1,ch[l]^48)),(void)0;
        int mid=l+r>>1; build(p<<1,l,mid),build(p<<1|1,mid+1,r);
        push_up(tree[p],p<<1,tree[p<<1|1]);
    }
    void update(int p,int l,int r,int k)
    {
        if(l==r)return tree[p]=(ch[l]=='B'?node(1,0,0):node(0,1,ch[l]^48)),(void)0;
        int mid=l+r>>1; k<=mid?update(p<<1,l,mid,k):update(p<<1|1,mid+1,r,k);
        push_up(tree[p],p<<1,tree[p<<1|1]);
    }
    vector<int> vec;
    void query(int p,int l,int r,int L,int R)
    {
        if(L<=l&&r<=R)return vec.push_back(p); int mid=l+r>>1;
        if(R>mid)query(p<<1|1,mid+1,r,L,R); if(L<=mid)query(p<<1,l,mid,L,R);
    }
    int query(int l,int r)
    {
        vec.clear(),query(1,1,n,l,r);
        Node ret=node(-1,-1,-1),tmp=node(-1,-1,-1);
        for(int p:vec)ret.b==-1?(ret=tree[p]):(push_up(tmp,p,ret),ret=tmp);
        return ret.l?ret.x:-1;
    }
} T;
main()
{
    scanf("%lld%lld%s",&n,&q,ch+1),pw[0]=1;
    for(int i=1;i<N;i++)pw[i]=10*pw[i-1]%MOD;
    ipw[N-1]=qpow(pw[N-1]);
    for(int i=N-2;~i;i--)ipw[i]=10*ipw[i+1]%MOD;
    T.build(1,1,n);
    while(q--)
    {
        int opt,l,r,k; char t[3]; scanf("%lld",&opt);
        if(opt==1)scanf("%lld%s",&k,t+1),ch[k]=t[1],T.update(1,1,n,k);
        else scanf("%lld%lld",&l,&r),printf("%lld\n",T.query(l,r));
    }
    return 0;
}