HT-75-NOI-T1 题解

· · 题解

题意

给出一个长为 n 的数字串 s,求其中有多少对等长子串 A,B,使得两串转为数字后 b=a+1,认为不同位置的子串不同。n\le 4\times 10^5

题解

先不考虑进位,则两串只有结尾一位不同。考虑枚举 a 末位取值 x,会得到两个集合 S,T,分别记录原串中所有 xx+1 的位置。这时从两个集合中各取一个位置 p,q,一定可以配对,产生至少 1 的贡献。还要考虑长为 p-1q-1 的两个前缀,两者的公共后缀可接在变化位前面,构成更长的合法对。因此还有最长公共后缀长度的额外贡献。因此总贡献为 \left|S\right|\cdot\left|T\right|+\sum_{p\in S,q\in T} w(p-1,q-1),其中 w(i,j) 表示长度分别为 i,j 的两个前缀的最长公共后缀长度。

前面一项容易计算,后面一项考虑在反串上建后缀数组维护所有前缀,并求出其 height 数组(以下简记为 h)。则贡献转化为 \sum_{p\in S,q\in T} w'(rk_{p-1},rk_{q-1}),其中 rk 是后缀数组中的排名,w'(i,j) 表示 i,j 之间 h 数组左开右闭的区间 min,即 \min_{k=\min(i,j)+1}^{\max(i,j)} h_k,也就是上面的最长公共后缀长度。这个考虑扫 h 数组,并维护一个递增的单调栈,每个点上记录对应区间内 rk_{p-1},rk_{q-1} 的个数,同时记录两种位置到当前点区间 min 的和,这样单调栈删点时也能修改当前值。此时若当前点也是合法位置,就给答案加上另一种位置的贡献和即可。若拿出所有关键点后,将中间每部分压成区间 min 一个数,可以做到单次复杂度关于关键点数线性,但需要对 h 数组 O(n\log n) 预处理出 ST 表,以备查询区间 min。

现在还需要考虑进位,注意到以某个 9a 的结尾时,进位会将后缀极长的一段 9 全变为 0,再将上一位增加 1。这样对应的 b 有同样数量的后缀 0,且上一位一定非 0,因此该 0 后缀也是极长的。同时 a 不能每一位均为 9,否则 b 无法满足长度限制。考虑拿出所有不在开头且无法向前扩展的 90 连续段,它们的总个数是 O(n) 级别的,因为一个长为 c 的极长连续段会产生 c 个段。找出所有长度相等的段,以它们左边一位为 p,q,则要求该位满足加的限制,所求的式子和上面一样,同样求一下即可。时间复杂度 O(n\log n),在后缀数组和 ST 表预处理,事实上两个均可以科技做到线性

代码

#include<iostream>
#include<algorithm>
#include<vector>
#define ll long long
#define pb push_back
#define pii pair<int,int>
#define fi first
#define se second
using namespace std;
const int N=4e5+10;
const int K=20+5;
int n,a[N],m,p,sa[N],rk[N],id[N],cnt[N],tk[N<<1],h[N],c[N],d[N];
string s; ll res;
vector <int> pos[K],poa[11][N],pob[11][N];
vector <pii> tv;
struct STt
{
    int w[K][N],lg[N],po[K];
    void build()
    {
        lg[0]=-1,po[0]=1;
        for(int i=1;i<=n;i++) w[0][i]=h[i],lg[i]=lg[i>>1]+1;
        for(int i=1;i<=20;i++) po[i]=(po[i-1]<<1);
        for(int k=1;k<=20;k++) for(int i=1;i+po[k]-1<=n;i++)
            w[k][i]=min(w[k-1][i],w[k-1][i+po[k-1]]);
    }
    int query(int l,int r)
    {
        int k=lg[r-l+1];
        return min(w[k][l],w[k][r-po[k]+1]);
    }
}T;
int tw[N],kd[N],cn,st[N],hd,aa[N][2];
void solve()
{
    sort(tv.begin(),tv.end()),cn=0; int lp=0;
    for(pii te:tv) if(te.fi>=1&&te.fi<=n)
    {
        if(te.fi>lp+1) ++cn,tw[cn]=T.query(lp+1,te.fi-1),kd[cn]=-1;
        ++cn,tw[cn]=h[te.fi],kd[cn]=te.se,lp=te.fi;
    }
    ll cur[2]={0,0}; hd=0;
    for(int i=cn;i;i--)
    {
        if(kd[i]!=-1) res+=cur[kd[i]^1];
        ll tc[2]={0,0};
        while(hd&&tw[i]<=tw[st[hd]])
        {
            int tp=st[hd]; hd--;
            for(int o=0;o<2;o++) cur[o]-=1ll*aa[tp][o]*tw[tp],tc[o]+=aa[tp][o];
        }
        if(kd[i]!=-1) tc[kd[i]]++;
        for(int o=0;o<2;o++) cur[o]+=1ll*tc[o]*tw[i],aa[i][o]=tc[o];
        st[++hd]=i;
    }
}
int main()
{
    ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
    cin>>n>>s,p=10,a[0]=d[0]=-1;
    for(int i=1;i<=n;i++) a[i]=s[i-1]-'0',pos[a[i]].pb(i);
    for(int i=1;i<=n;i++)
    {
        c[i]=(a[i]==9?c[i-1]+1:0),d[i]=(a[i-1]==9?d[i-1]:a[i-1]);
        if(c[i]&&d[i]!=-1) poa[d[i]][c[i]].pb(i);
    }
    for(int i=1;i<=n;i++)
    {
        c[i]=(!a[i]?c[i-1]+1:0),d[i]=(!a[i-1]?d[i-1]:a[i-1]);
        if(c[i]&&d[i]!=-1) pob[d[i]][c[i]].pb(i);
    }
    reverse(a+1,a+1+n);
    for(int i=1;i<=n;i++) cnt[rk[i]=a[i]+1]++;
    for(int i=2;i<=p;i++) cnt[i]+=cnt[i-1];
    for(int i=n;i>=1;i--) sa[cnt[rk[i]]--]=i;
    for(int w=1;;w<<=1)
    {
        m=p,p=0; int ct=0;
        for(int i=n-w+1;i<=n;i++) id[++ct]=i;
        for(int i=1;i<=n;i++) if(sa[i]>w) id[++ct]=sa[i]-w;
        for(int i=1;i<=m;i++) cnt[i]=0;
        for(int i=1;i<=n;i++) cnt[rk[i]]++,tk[i]=rk[i];
        for(int i=2;i<=m;i++) cnt[i]+=cnt[i-1];
        for(int i=n;i>=1;i--) sa[cnt[rk[id[i]]]--]=id[i];
        for(int i=1;i<=n;i++)
        {
            if(tk[sa[i]]!=tk[sa[i-1]]||tk[sa[i]+w]!=tk[sa[i-1]+w]) p++;
            rk[sa[i]]=p;
        }
        if(p==n) break;
    }
    int k=0;
    for(int i=1;i<=n;i++)
    {
        if(k) k--;
        int x=sa[rk[i]-1];
        while(x+k<=n&&i+k<=n&&a[x+k]==a[i+k]) k++;
        h[rk[i]]=k;
    }
    T.build();
    for(int i=0;i<9;i++) if(pos[i].size()&&pos[i+1].size())
    {
        tv.clear(),res+=1ll*pos[i].size()*pos[i+1].size();
        for(int x:pos[i]) tv.pb({rk[n-(x-1)+1],0});
        for(int x:pos[i+1]) tv.pb({rk[n-(x-1)+1],1});
        solve();
    }
    for(int i=1;i<=n;i++) for(int j=0;j<9;j++) if(poa[j][i].size()&&pob[j+1][i].size())
    {
        tv.clear(),res+=1ll*poa[j][i].size()*pob[j+1][i].size();
        for(int x:poa[j][i]) tv.pb({rk[n-(x-i-1)+1],0});
        for(int x:pob[j+1][i]) tv.pb({rk[n-(x-i-1)+1],1});
        solve();
    }
    cout<<res;
    return 0;
}