题解:P2779 [AHOI2016初中组] 黑白序列
limingyuan333 · · 题解
题意:
你有
思路:
首先考虑我们实际上存在一个
#include<bits/stdc++.h>
using namespace std;
const int mod=1e9+7;
const int MAXN=5e5+10;
int dp[MAXN];
string s;
void add(int &x,int y){
x+=y;if(x>=mod) x-=mod;
}
int check(int l,int r){
if((r-l)%2==0) return 0;
int len=r-l+1;len/=2;int f1=1;
for(int i=l;i<=l+len-1;i++) f1&=(s[i]!='W');
for(int i=l+len;i<=r;i++) f1&=(s[i]!='B');
return f1;
}
signed main(){
ios::sync_with_stdio(false);
cin.tie(NULL),cout.tie(NULL);
int n;cin>>n;int sum=0;
cin>>s;s=' '+s;dp[0]=1;
for(int i=1;i<=n;i++){
for(int j=1;j<i;j++){
add(dp[i],1ll*dp[j-1]*check(j,i)%mod);
}
}cout<<dp[n];
return 0;
}
我们发现这个代码的判断是否可以染色有点慢,考虑优化一下,发现本质上是
#include<bits/stdc++.h>
using namespace std;
const int mod=1e9+7;
const int MAXN=5e5+10;
int dp[MAXN];
string s;
void add(int &x,int y){
x+=y;if(x>=mod) x-=mod;
}
int R0[MAXN],R1[MAXN];
signed main(){
ios::sync_with_stdio(false);
cin.tie(NULL),cout.tie(NULL);
int n;cin>>n;int sum=0;
cin>>s;s=' '+s;dp[0]=1;int lst0=n+1,lst1=n+1;
for(int i=n;i>=1;i--){
if(s[i]=='B') lst0=i;
if(s[i]=='W') lst1=i;
R0[i]=lst0,R1[i]=lst1;
}
for(int i=1;i<=n;i++){
for(int j=1;j<i;j++){
int len=(i-j+1)/2;
if((i-j)%2==1&&R1[j]>=j+len&&R0[j+len]>i){
add(dp[i],dp[j-1]);
}
}
}cout<<dp[n];
return 0;
}
发现上面代码有很明显的单调性,考虑答案贡献区间,首先满足
#include<bits/stdc++.h>
using namespace std;
const int mod=1e9+7;
const int MAXN=5e5+10;
int dp[MAXN];
string s;
void add(int &x,int y){
x+=y;if(x>=mod) x-=mod;
}
int R0[MAXN],R1[MAXN],n;
vector<int>vec[MAXN];
struct BIT{
int t[MAXN];
int lowbit(int x){
return x&-x;
}
void add(int x,int y){
for(int i=x;i<=n;i+=lowbit(i)) (t[i]+=y)%=mod;
}
int sum(int x){
int res=0;
if(x<=0) return 0;
for(int i=x;i;i-=lowbit(i)) (res+=t[i])%=mod;
return res;
}
int qry(int l,int r){
return (sum(r)-sum(l-1)+mod)%mod;
}
}bit[2];
signed main(){
ios::sync_with_stdio(false);
cin.tie(NULL),cout.tie(NULL);
cin>>n;int sum=0;
cin>>s;s=' '+s;dp[0]=1;int lst0=n+1,lst1=n+1;
for(int i=n;i>=1;i--){
if(s[i]=='B') lst0=i;
if(s[i]=='W') lst1=i;
R0[i]=lst0,R1[i]=lst1;
}lst0=0;
for(int i=1;i<=n;i++){
if(s[i]=='B') lst0=i;
int len=i-lst0;
int t=i-2*len+1;
dp[i]=bit[(i+1)%2].qry(t,i-1);
if(R1[i]!=i) bit[i%2].add(i,dp[i-1]);
if(i+(R1[i]-i)*2-1<=n) vec[i+(R1[i]-i)*2-1].push_back(i);
for(auto x:vec[i]) bit[x%2].add(x,mod-dp[x-1]);
}cout<<dp[n];
return 0;
}