[题解]P16827 [AFOI 2025] D.谐音替换
更好的阅读体验
这下真的是连退役选手都会做的题了。
思路
要把
接下来的一个基本思路就是考虑能不能只枚举一个值。若枚举
这给了我们一定的启示,我们需要主动考虑中间段的合法条件,此时只需注意到一个事实就做完了:前缀的前缀依旧是前缀,后缀的后缀依旧是后缀。
- 若中间段
[L,R] 是一个前缀,则一定存在一个k \geq R ,使得[L,k] 是S_{1 \sim n} 中的一个前缀。不妨记pmax_i = k - i + 1 ,其中[i,k] 是S_{1 \sim n} 的一个前缀,而[i,k + 1] 不是。则[L,R] 是一个前缀的条件显然就是R \leq L + pmax_L - 1 。 - 若中间段
[L,R] 是一个后缀的情况同理,不妨记smax_i = i - k + 1 ,其中[k,i] 是一个后缀,而[k - 1,i] 不是,则[L,R] 是一个后缀的条件为L \geq R - smax_R + 1 。
对于前者,枚举
注意到中间段同时为前缀和后缀的情况被统计了两次,考虑容斥减掉。使中间段同时为前缀和后缀的
计算
Code
#include <bits/stdc++.h>
#define re register
#define int long long
using namespace std;
const int N = 5e5 + 10;
const int base1 = 1129,mod1 = 102401027;
const int base2 = 1249,mod2 = 1000000241;
int n,q;
char s[N];
int pw1[N],hs1[N],pw2[N],hs2[N];
int pmax[N],smax[N],vap[N],vas[N];
vector<int> S[N];
unordered_set<int> stp1,sts1,stp2,sts2;
inline int read(){
int r = 0,w = 1;
char c = getchar();
while (c < '0' || c > '9'){
if (c == '-') w = -1;
c = getchar();
}
while (c >= '0' && c <= '9'){
r = (r << 3) + (r << 1) + (c ^ 48);
c = getchar();
}
return r * w;
}
inline void init(int n){
pw1[0] = pw2[0] = 1;
for (re int i = 1;i <= n;i++) pw1[i] = pw1[i - 1] * base1 % mod1;
for (re int i = 1;i <= n;i++) pw2[i] = pw2[i - 1] * base2 % mod2;
}
inline int geths1(int l,int r){ return (hs1[r] - hs1[l - 1] * pw1[r - l + 1] % mod1 + mod1) % mod1; }
inline int geths2(int l,int r){ return (hs2[r] - hs2[l - 1] * pw2[r - l + 1] % mod2 + mod2) % mod2; }
struct{
#define lowbit(x) ((x) & -(x))
int n,tr[N];
inline void modify(int x,int k){
for (re int i = x;i <= n;i += lowbit(i)) tr[i] += k;
}
inline int query(int x){
int sum = 0;
for (re int i = x;i;i -= lowbit(i)) sum += tr[i];
return sum;
}
inline int query(int l,int r){ return query(r) - query(l - 1); }
#undef lowbit
}T;
signed main(){
init(5e5);
n = read(),q = read();
for (re int i = 1;i <= n;i++){
scanf("%s",s + 1);
int len = strlen(s + 1);
for (re int j = 1;j <= len;j++) hs1[j] = (hs1[j - 1] * base1 + s[j]) % mod1;
for (re int j = 1;j <= len;j++) hs2[j] = (hs2[j - 1] * base2 + s[j]) % mod2;
for (re int j = 1;j <= len;j++) stp1.insert(geths1(1,j)),sts1.insert(geths1(j,len));
for (re int j = 1;j <= len;j++) stp2.insert(geths2(1,j)),sts2.insert(geths2(j,len));
}
while (q--){
int cnt1 = 0,cnt2 = 0,cnt3 = 0;
scanf("%s",s + 1);
int len = strlen(s + 1);
T.n = len;
fill(T.tr,T.tr + len + 3,0);
for (re int i = 1;i <= len;i++) hs1[i] = (hs1[i - 1] * base1 + s[i]) % mod1;
for (re int i = 1;i <= len;i++) hs2[i] = (hs2[i - 1] * base2 + s[i]) % mod2;
for (re int i = 1;i <= len;i++){
int l = 0,r = len - i + 1;
while (l < r){
int mid = (l + r + 1) >> 1;
if (stp1.find(geths1(i,i + mid - 1)) != stp1.end() && stp2.find(geths2(i,i + mid - 1)) != stp2.end()) l = mid;
else r = mid - 1;
} pmax[i] = l;
}
for (re int i = 1;i <= len;i++){
int l = 0,r = i;
while (l < r){
int mid = (l + r + 1) >> 1;
if (sts1.find(geths1(i - mid + 1,i)) != sts1.end() && sts2.find(geths2(i - mid + 1,i)) != sts2.end()) l = mid;
else r = mid - 1;
} smax[i] = l;
}
for (re int i = 1;i <= len;i++) vap[i] = vap[i - 1] + (i <= pmax[1] || i == smax[i]);
for (re int i = 1;i <= len;i++) vas[i] = vas[i - 1] + (len - i + 1 == pmax[i] || len - i + 1 <= smax[len]);
for (re int i = 1;i < len;i++){
if (vap[i] - vap[i - 1]) cnt1 += (vas[min(i + pmax[i + 1] + 1,len)] - vas[i + 1]);
}
for (re int i = 2;i <= len;i++){
if (vas[i] - vas[i - 1]) cnt2 += (vap[i - 2] - vap[max(1ll,i - smax[i - 1] - 1) - 1]);
}
for (re int i = 1;i <= len;i++){
if (vap[i] - vap[i - 1]) S[min(i + pmax[i + 1] + 1,len)].push_back(i);
}
for (re int i = len;i > 1;i--){
for (int x:S[i]) T.modify(x,1);
if (vas[i] - vas[i - 1]) cnt3 += T.query(max(1ll,i - smax[i - 1] - 1),i - 2);
} printf("%lld\n",cnt1 + cnt2 - cnt3);
for (re int i = 1;i <= len;i++) S[min(i + pmax[i],len)].clear();
}
return 0;
}