题解:P15650 [省选联考 2026] 摩卡串 / string

· · 题解

首先 B 性质显然是一边填一边记极长 0 连续段长度(对 n 取较小值)和小于的个数,C 性质也能得出一个类似的做法。所以我们考虑一个个填数,填一个数后我们需要关心有多少个后缀字典序严格小于给定串(已经出现不同且第一个不同的位置字典序较小),字典序大致等于给定串(未出现不同,只是长度小)的位置匹配都到了哪里。

我们发现字典序大致等于的极长后缀确定了以后更短的后缀都确定了,前面的我们只关心严格小于的个数。更美妙的事情是这个极长后缀其实是当前跑 kmp 匹配到的位置,所以可以通过 kmp 自动机(一个串建出的 AC 自动机)实现 O(1) 转移。同时子串的限制就转化为了 kmp 匹配成功过。

我们可以记 dp_{0/1,i,j,k} 表示是否匹配成功过,kmp 跑到了哪里,前面有多少个严格小于,目前有多少个小于。转移形式是枚举填的字符跑一个边权为 1 的最短路,转移系数只与 i 有关可以预处理,bfs 跑最短路即可实现 O(1) 转移。

容易发现 j 这一维上界是一个自然根号,因为每个严格小于会贡献后面的长度个,第 i 个后面至少产生 j-i+1 的贡献。

所以状态数和复杂度总共是 O(nk\sqrt{k}) 的,卡空间的话把转移前驱压一下,可以压成一个 int,不要拿结构体存(左对齐,会按空间最大的类型算),但其实精细实现结构体也能存下。

拿到代码了,发现被卡常了,bfs 到答案后直接退出就可以快一倍,下面是修改后的代码。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
#define fi first
#define se second
#define pb push_back
#define eb emplace_back
const int M=3005;
const int N=205;
const int mod=998244353;
const int inf=1e9+7;
struct node{
    bool o;
    short x;
    char y;
    short s;
    node(int op=0,int X=0,int Y=0,int S=0){
        o=(bool)op,x=(short)X,y=(char)Y,s=(short)S;
    }
    //int o,x,y,s;
};
int tid,n,m,nx[N][2],fl[N],bj[N][N],tr[N][2],tt[N][2];
short dis[2][N][83][M];
node pre[2][N][83][M];
bool vis[2][N][83][M];
char s[N],ss[M*5],cc[2][N][83][M];
queue<node> q;
void sol(){
    scanf("%d%d",&n,&m);
    scanf("%s",s+1);
    if(n==1&&s[1]=='0'){
        if(m==0) puts("0");
        else puts("Impossible");
        return;
    }
    int tot=0;
    for(int i=1;i<=n;i++) for(int j=1;j<=n;j++) bj[i][j]=-1;
    for(int i=1;i<=n;i++){
        int op=0;
        for(int j=1;j<=n-i+1;j++){
            if(op) tot++;
            else{
                if(s[i+j-1]>s[j]) break;
                if(s[i+j-1]<s[j]) op=1;
                if(i!=1||j!=n) tot++;
            } 
            bj[i][i+j-1]=op;
        }
    }
    //printf("bj:%d\n",bj[2][3]);
    for(int i=0;i<=n;i++){
        for(int o=0;o<2;o++){
            char ch='0'+o;
            tr[i][o]=0;
            for(int j=1;j<=i;j++){
                if(bj[j][i]==1) tr[i][o]++;
                if(bj[j][i]==0&&s[i-j+2]>ch) tr[i][o]++;
                if(bj[j][i]==0&&s[i-j+2]==ch&&i-j+2<n) tr[i][o]++;
            }
            if(s[1]>ch) tr[i][o]++;
            if(s[1]==ch&&n>1) tr[i][o]++;
            //printf("tr:%d %d %d\n",i,o,tr[i][o]);
        }
    }
    if(tot>m){
        puts("Impossible");
        return;
    }
    fl[1]=0;
    for(int i=2;i<=n;i++){
        int j=fl[i-1];
        while(j&&s[j+1]!=s[i]) j=fl[j];
        fl[i]=j+(s[j+1]==s[i]?1:0);
        //printf("fail:%d %d\n",i,fl[i]);
    }
    for(int i=1;i<=n+1;i++){
        nx[i-1][0]=nx[i-1][1]=0;
        for(int o=0;o<2;o++){
            if(s[i]-'0'==o) nx[i-1][o]=i;
            else nx[i-1][o]=nx[fl[i-1]][o];
        }
    }
    for(int x=0;x<=n;x++){
        for(int o=0;o<2;o++){
            int tx=nx[x][o];
            char ch='0'+o;
            int ty=0;
            if(tx>x) tt[x][o]=0;
            else if(tx==0){
                for(int i=1;i<=x;i++){
                    if(bj[i][x]==1) ty++;
                    if(bj[i][x]==0&&s[x-i+2]>ch) ty++;
                }
                if(s[1]>ch) ty++;
            }else{
                for(int i=1;i<=x+1-tx;i++){
                    if(bj[i][x]==1) ty++;
                    if(bj[i][x]==0&&s[x-i+2]>ch) ty++;
                }    
            }
            tt[x][o]=ty;
        }
    }
    //for(int i=0;i<=n;i++) for(int o=0;o<2;o++) printf("nx:%d %d %d\n",i,o,nx[i][o]);
    memset(dis,0x3f,sizeof(dis));
    memset(vis,0,sizeof(vis));
    //return;
    while(!q.empty()) q.pop();
    dis[0][0][0][0]=0,vis[0][0][0][0]=1,q.push({0,0,0,0});
    while(!q.empty()){
        node ns=q.front();q.pop();
        int op=ns.o,x=ns.x,y=ns.y,k=ns.s,d=dis[op][x][y][k];
        if(op==1&&k==m) break;
        for(int o=0;o<2;o++){
            char ch='0'+o;
            int to=op,tx=nx[x][o],ty=y,tk=k;
            tk+=y+tr[x][o],ty+=tt[x][o];
            if(tx==n) to=1;
            if(tk>m) continue;
            if(vis[to][tx][ty][tk]) continue;
            vis[to][tx][ty][tk]=1,dis[to][tx][ty][tk]=d+1;
            cc[to][tx][ty][tk]=ch,pre[to][tx][ty][tk]=ns;
            q.push({to,tx,ty,tk});
        }
    }
    int ans=10000;node nw={0,0,0,0};
    for(int i=0;i<=n;i++) for(int j=0;j<=80;j++){
        if(dis[1][i][j][m]<ans) ans=dis[1][i][j][m],nw={1,i,j,m};
    } 
    if(ans==10000) puts("Impossible");
    else{
        int i=ans;
        while(i){
            int op=nw.o,x=nw.x,y=nw.y,k=nw.s;
            //printf("i:%d %d %d %d %d\n",i,op,x,y,k);
            ss[i]=cc[op][x][y][k],nw=pre[op][x][y][k],i--;
        }
        for(int i=1;i<=ans;i++) putchar(ss[i]);
        puts("");
        int sum=0;
        for(int i=1;i<=ans;i++){
            int op=0;
            for(int j=1;j<=ans-i+1;j++){
                if(op) sum++;
                else{
                    if(j>n) break;
                    if(ss[i+j-1]>s[j]) break;
                    if(ss[i+j-1]<s[j]) op=1;
                    if(j<n||op) sum++;
                } 
            }
        }
        //printf("sum:%d %d\n",sum,m);
    }
}
int main(){
    //freopen("string.in","r",stdin);
    //freopen("string.out","w",stdout);
    int T=1;
    scanf("%d%d",&tid,&T);
    while(T--) sol();
    return 0;
}