P4324 [JSOI2016]扭动的回文串 (manacher+二分+hash)

· · 个人记录

题面

首先前两种情况跑 manacher 就行。

考虑重点的第三种情况:

假设当前的总回文串的回文中心在 A 串,那么发现最大化 A 串的回文长度肯定是最优的。

然后用二分 + hash 求出 A 中该回文串左侧的部分(反向读)与 B 串中从该回文串右端点下标开始的部分的最长匹配长度,以此向两侧扩展。

回文中心在 B 串的情况类似。

复杂度 $O(N \log N)$

#include<algorithm>
#include<iostream>
#include<cstdio>
#include<cstdio>
#include<cstring>
// 64-bit unsigned type used for the polynomial hashes; arithmetic wraps
// modulo 2^64, which serves as the implicit hash modulus.
typedef unsigned long long ull;
using namespace std;

// Capacity of every global array below.
const int MAXN = 200050;
// Multiplier of the polynomial hash.
const int base = 233;

// Forward prefix hashes of a and b (index i covers characters 1..i).
ull hsha1[MAXN];
ull hshb1[MAXN];
// Backward (suffix) hashes of a and b (index i covers characters i..N).
ull hsha2[MAXN];
ull hshb2[MAXN];
// pown[k] holds base^k.
ull pown[MAXN];

// Manacher radii over the interleaved form of a (len1) and of b (len2).
int len1[MAXN];
int len2[MAXN];

// Input strings, stored 1-indexed.
char a[MAXN];
char b[MAXN];
// Scratch buffer holding the separator-interleaved string built by manacher().
char s[MAXN];

// Length of each input string.
int N;
// Best palindrome length found so far.
int ans;
// Runs Manacher's algorithm on the 1-indexed string ts[1..N].
//
// The string is first interleaved into the global buffer s as
//   s[0] = '~', s[odd] = '$', s[2k] = ts[k], s[2N+1] = '$',
// so every palindrome of ts has a single centre in s.
//
// On return len[c] (1 <= c <= 2N) is the radius at centre c: the palindrome
// in s spans s[c - len[c] + 1 .. c + len[c] - 1]. As a side effect the global
// ans is raised to the length of the longest palindromic substring of ts.
void manacher(char *ts, int *len)
{
    const int last = 2 * N;

    // Build the interleaved string; '~' is the left sentinel that stops
    // the expansion loop (assumed not to occur in the input).
    s[0] = '~';
    for (int k = 1; k <= N; ++k)
    {
        s[2 * k - 1] = '$';
        s[2 * k] = ts[k];
    }
    s[last + 1] = '$';

    int reach = 0; // one past the rightmost position covered by a known palindrome
    int pivot = 0; // centre of the palindrome that attains `reach`
    for (int c = 1; c <= last; ++c)
    {
        // Seed the radius from the mirror centre when c lies inside the
        // palindrome around pivot, then expand character by character.
        int radius = 1;
        if (c < reach)
            radius = min(reach - c, len[2 * pivot - c]);
        while (s[c - radius] == s[c + radius])
            ++radius;
        len[c] = radius;

        if (c + radius > reach)
        {
            reach = c + radius;
            pivot = c;
        }

        // Shrink the span to its outermost even positions (the real
        // characters); halving those gives indices into ts.
        int first = c - radius + 1;
        int final_ = c + radius - 1;
        if (first % 2 != 0)
            ++first;
        if (final_ % 2 != 0)
            --final_;
        if (first <= final_)
            ans = max(ans, final_ / 2 - first / 2 + 1);
    }
}
// Hash of positions l..r taken from a forward prefix-hash array, i.e. one
// built as hsh[i] = hsh[i - 1] * base + c[i]. The first character of the
// window carries the highest power of base.
ull gethsh1(ull *hsh, int l, int r)
{
    const int width = r - l + 1;
    // Contribution of the prefix 1..l-1, shifted up to align with hsh[r].
    const ull shiftedPrefix = hsh[l - 1] * pown[width];
    return hsh[r] - shiftedPrefix;
}
// Hash of positions l..r taken from a backward (suffix) hash array, i.e. one
// built as hsh[i] = hsh[i + 1] * base + c[i]. The last character of the
// window carries the highest power of base, so the value equals the forward
// hash of the window read right-to-left.
ull gethsh2(ull *hsh, int l, int r)
{
    const int width = r - l + 1;
    // Contribution of the suffix r+1.., shifted up to align with hsh[l].
    const ull shiftedSuffix = hsh[r + 1] * pown[width];
    return hsh[l] - shiftedSuffix;
}
bool check1(int len, int l, int r)
{
    return gethsh2(hsha2, l - len + 1, l) == gethsh1(hshb1, r, r + len - 1);
}
// Hash comparison used for centres located in b: a[l - len + 1 .. l] read
// right-to-left against b[r .. r + len - 1] read left-to-right.
// Returns true when the two hashes are equal.
bool check2(int len, int l, int r)
{
    if (gethsh2(hsha2, l - len + 1, l) != gethsh1(hshb1, r, r + len - 1))
        return false;
    return true;
}
void solve()
{
    for (int i = 1; i <= 2 * N; ++i)
    {
        int lid = i - len1[i] + 1;
        int rid = i + len1[i] - 1;
        lid += lid & 1;
        rid -= rid & 1;
        lid >>= 1;
        rid >>= 1;
        // cout << lid << " " << rid << endl;
        int res = rid - lid + 1;
        int l = 0, r = N;
        int len = 0;
        while (l <= r)
        {
            int mid = (l + r) >> 1;
            if (check1(mid, lid - 1, rid))
            {
                l = mid + 1;
                len = mid;
            }
            else
                r = mid - 1;
        }
        ans = max(ans, res + len * 2);
    }
    for (int i = 1; i <= 2 * N; ++i)
    {
        int lid = i - len2[i] + 1;
        int rid = i + len2[i] - 1;
        lid += lid & 1;
        rid -= rid & 1;
        lid >>= 1;
        rid >>= 1;
        // cout << lid << " " << rid << endl;
        int res = rid - lid + 1;
        int l = 0, r = N;
        int len = 0;
        while (l <= r)
        {
            int mid = (l + r) >> 1;
            if (check2(mid, lid, rid + 1))
            {
                l = mid + 1;
                len = mid;
            }
            else
                r = mid - 1;
        }
        ans = max(ans, res + len * 2);
    }
}   
// Reads N and the two strings (stored 1-indexed in a and b), precomputes the
// powers of base and the forward/backward hashes, runs Manacher on both
// strings, evaluates the twisted case and prints the answer.
int main()
{
    scanf("%d", &N);
    scanf("%s", a + 1);
    scanf("%s", b + 1);

    // Powers of base together with the forward prefix hashes.
    pown[0] = 1;
    for (int k = 1; k <= N; ++k)
    {
        pown[k] = pown[k - 1] * base;
        hsha1[k] = hsha1[k - 1] * base + a[k];
        hshb1[k] = hshb1[k - 1] * base + b[k];
    }
    // Backward (suffix) hashes, built from the right end.
    for (int k = N; k > 0; --k)
    {
        hsha2[k] = hsha2[k + 1] * base + a[k];
        hshb2[k] = hshb2[k + 1] * base + b[k];
    }

    // Palindromes lying entirely inside a or inside b.
    manacher(a, len1);
    manacher(b, len2);
    // Palindromes that start in a and continue in b.
    solve();

    printf("%d\n", ans);
    return 0;
}