[题解]P16827 [AFOI 2025] D.谐音替换

· · 题解

更好的阅读体验

这下真的是连退役选手都会做的题了。

思路

要把 T 划分成三个串,也就是有两个分界点,枚举这两个点 x,y 哈希判断 [1,x],(x,y),[y,len] 是否在 S_{1 \sim n} 的前缀/后缀中出现没有即可。复杂度 \Theta(\sum L^2)

接下来的一个基本思路就是考虑能不能只枚举一个值。若枚举 x[1,x] 是否合法是容易得到的,若枚举 y[y,len] 是否合法也是容易得到的。不妨记 vap_i,vas_i = 0/1 表示 x = i 的前缀和 y = i 的后缀是否合法。

这给了我们一定的启示,我们需要主动考虑中间段的合法条件,此时只需注意到一个事实就做完了:前缀的前缀依旧是前缀,后缀的后缀依旧是后缀。

对于前者,枚举 x,则使中间段为前缀的 y 必须满足 y > x + pmax_{x + 1},结尾段合法的条件是 vas_y = 1,查 vas 的一个后缀和即可;对于后者,枚举 y,则使中间段为后缀的 x 必须满足 x < y - smax_{y - 1},开头段合法的条件为 vap_x = 1,查 vap 的一个前缀和即可。

注意到中间段同时为前缀和后缀的情况被统计了两次,考虑容斥减掉。使中间段同时为前缀和后缀的 x,y 需要满足 y > x + pmax_{x + 1}x < y - smax_{y - 1},扫描线一下就算出来了。

计算 pmax_i,smax_i 可以直接二分,复杂度 \Theta(\sum L \log L)

Code

#include <bits/stdc++.h>
#define re register
#define int long long

using namespace std;

const int N = 5e5 + 10;
const int base1 = 1129,mod1 = 102401027;
const int base2 = 1249,mod2 = 1000000241;
int n,q;
char s[N];
int pw1[N],hs1[N],pw2[N],hs2[N];
int pmax[N],smax[N],vap[N],vas[N];
vector<int> S[N];
unordered_set<int> stp1,sts1,stp2,sts2;

inline int read(){
    int r = 0,w = 1;
    char c = getchar();
    while (c < '0' || c > '9'){
        if (c == '-') w = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9'){
        r = (r << 3) + (r << 1) + (c ^ 48);
        c = getchar();
    }
    return r * w;
}

inline void init(int n){
    pw1[0] = pw2[0] = 1;
    for (re int i = 1;i <= n;i++) pw1[i] = pw1[i - 1] * base1 % mod1;
    for (re int i = 1;i <= n;i++) pw2[i] = pw2[i - 1] * base2 % mod2;
}

inline int geths1(int l,int r){ return (hs1[r] - hs1[l - 1] * pw1[r - l + 1] % mod1 + mod1) % mod1; }
inline int geths2(int l,int r){ return (hs2[r] - hs2[l - 1] * pw2[r - l + 1] % mod2 + mod2) % mod2; }

struct{
    #define lowbit(x) ((x) & -(x))

    int n,tr[N];

    inline void modify(int x,int k){
        for (re int i = x;i <= n;i += lowbit(i)) tr[i] += k;
    }

    inline int query(int x){
        int sum = 0;
        for (re int i = x;i;i -= lowbit(i)) sum += tr[i];
        return sum;
    }
    inline int query(int l,int r){ return query(r) - query(l - 1); }

    #undef lowbit
}T;

signed main(){
    init(5e5);
    n = read(),q = read();
    for (re int i = 1;i <= n;i++){
        scanf("%s",s + 1);
        int len = strlen(s + 1);
        for (re int j = 1;j <= len;j++) hs1[j] = (hs1[j - 1] * base1 + s[j]) % mod1;
        for (re int j = 1;j <= len;j++) hs2[j] = (hs2[j - 1] * base2 + s[j]) % mod2;
        for (re int j = 1;j <= len;j++) stp1.insert(geths1(1,j)),sts1.insert(geths1(j,len));
        for (re int j = 1;j <= len;j++) stp2.insert(geths2(1,j)),sts2.insert(geths2(j,len));
    }
    while (q--){
        int cnt1 = 0,cnt2 = 0,cnt3 = 0;
        scanf("%s",s + 1);
        int len = strlen(s + 1);
        T.n = len;
        fill(T.tr,T.tr + len + 3,0);
        for (re int i = 1;i <= len;i++) hs1[i] = (hs1[i - 1] * base1 + s[i]) % mod1;
        for (re int i = 1;i <= len;i++) hs2[i] = (hs2[i - 1] * base2 + s[i]) % mod2;
        for (re int i = 1;i <= len;i++){
            int l = 0,r = len - i + 1;
            while (l < r){
                int mid = (l + r + 1) >> 1;
                if (stp1.find(geths1(i,i + mid - 1)) != stp1.end() && stp2.find(geths2(i,i + mid - 1)) != stp2.end()) l = mid;
                else r = mid - 1;
            } pmax[i] = l;
        }
        for (re int i = 1;i <= len;i++){
            int l = 0,r = i;
            while (l < r){
                int mid = (l + r + 1) >> 1;
                if (sts1.find(geths1(i - mid + 1,i)) != sts1.end() && sts2.find(geths2(i - mid + 1,i)) != sts2.end()) l = mid;
                else r = mid - 1;
            } smax[i] = l;
        }
        for (re int i = 1;i <= len;i++) vap[i] = vap[i - 1] + (i <= pmax[1] || i == smax[i]);
        for (re int i = 1;i <= len;i++) vas[i] = vas[i - 1] + (len - i + 1 == pmax[i] || len - i + 1 <= smax[len]);
        for (re int i = 1;i < len;i++){
            if (vap[i] - vap[i - 1]) cnt1 += (vas[min(i + pmax[i + 1] + 1,len)] - vas[i + 1]);
        }
        for (re int i = 2;i <= len;i++){
            if (vas[i] - vas[i - 1]) cnt2 += (vap[i - 2] - vap[max(1ll,i - smax[i - 1] - 1) - 1]);
        }
        for (re int i = 1;i <= len;i++){
            if (vap[i] - vap[i - 1]) S[min(i + pmax[i + 1] + 1,len)].push_back(i);
        }
        for (re int i = len;i > 1;i--){
            for (int x:S[i]) T.modify(x,1);
            if (vas[i] - vas[i - 1]) cnt3 += T.query(max(1ll,i - smax[i - 1] - 1),i - 2);
        } printf("%lld\n",cnt1 + cnt2 - cnt3);
        for (re int i = 1;i <= len;i++) S[min(i + pmax[i],len)].clear();
    }
    return 0;
}