给我唐完了。
求字符集为
\{0,1,2\} 的两个字符串a,b 的最长公共不下降子序列长度。
本质上是将序列
为了方便表述,设
注意到枚举
枚举
第一个
考虑钦定
简写一下:
前两个条件是单调的,从小到大枚举
观察到
#include<bits/stdc++.h>
#define il inline
using namespace std;
const int maxn=5000010;
const int inf=1<<30;
il int read(){
int x=0;
char c=getchar();
for(;!(c>='0'&&c<='9');c=getchar());
for(;c>='0'&&c<='9';c=getchar())
x=(x<<1)+(x<<3)+c-'0';
return x;
}
int Test,n,Mn,Mx,N;
int ying0[maxn],ying2[maxn];
char a[maxn],b[maxn];
int w1[maxn],w2[maxn];
int h0[maxn],h2[maxn];
int S[maxn][3],T[maxn][3];
int Tree[maxn<<2];
void Add(int k,int x){for(k=N-k+1;k<=N;k+=k&-k)Tree[k]=max(Tree[k],x);}
int Que(int k,int Mx=-inf){for(k=N-k+1;k;k-=k&-k)Mx=max(Mx,Tree[k]);return Mx;}
int calc(){
memset(Tree,128,sizeof(Tree));
for(int i=1;i<=n;i++)
for(int j=0;j<3;j++)
S[i][j]=S[i-1][j]+(a[i]==j);
for(int i=1;i<=n;i++)
for(int j=0;j<3;j++)
T[i][j]=T[i-1][j]+(b[i]==j);
ying0[0]=0;
for(int i=1,j=1,tot1=0,tot2=0;i<=n;i++){
tot1+=(a[i]==0);
while(j<=n&&tot2<tot1)
tot2+=(b[j]==0),j++;
ying0[i]=j-1;
}
ying2[n]=n;
for(int i=n,j=n,tot1=0,tot2=0;i;i--){
tot1+=(a[i]==2);
while(j&&tot2<tot1)
tot2+=(b[j]==2),j--;
ying2[i-1]=j;
}
Mn=inf,Mx=-inf;
for(int i=0;i<=n;i++){
h2[i]=S[i][1]-T[ying2[i]][1];
h0[i]=S[i][1]-T[ying0[i]][1];
w1[i]=min(S[i][0],T[ying0[i]][0])-S[i][1];
w2[i]=S[i][1]+min(S[n][2]-S[i][2],T[n][2]-T[ying2[i]][2]);
Mn=min(Mn,min(h2[i],h0[i]));
Mx=max(Mx,max(h2[i],h0[i]));
}N=Mx-Mn+1;
int ans=0;
for(int i=0,r=0;i<=n;i++){
while(r<=i&&ying0[r]<=ying2[i]) Add(h0[r]-Mn+1,w1[r]),r++;
ans=max(ans,Que(h2[i]-Mn+1)+w2[i]);
}
return ans;
}
int main(){
// freopen("b.in","r",stdin);
// freopen("b_spj.out","w",stdout);
Test=read();
while(Test--){
scanf("%s%s",a+1,b+1);
n=strlen(a+1);
for(int i=1;i<=n;i++)
a[i]-='0',b[i]-='0';
int ans1=calc();
for(int i=1;i<=n;i++)
swap(a[i],b[i]);
int ans2=calc();
printf("%d\n",max(ans1,ans2));
}
return 0;
}