题解 P3649 【[APIO2014]回文串】

· · 个人记录

这篇博客介绍的是本题的SAM解法。

张天扬有一篇文章叫做《APOI2014 回文串 解题报告》,里面给出了SAM解决这个题的解法,很可惜目前网上并没有法子直接找到这篇文章。幸好本题在洛谷里还有一位Treeloveswater大佬给出了SAM做法,不过很可惜他图床挂了,并且无法联系到他。我只能拿着他的代码找机房的一位好兄弟帮忙看了看,然后他给我讲了一下这个做法。我有一些观点与Treeloveswater大佬并不相同,如有错误,敬请大家指正。

先说一个显然的事情,一个串倒过来和它正着长得一样,它就是回文串。

然后我分算法流程,正确性,复杂度三部分来说明一下这个做法。

算法流程第一步就是建立SAM,过程中除了基本的那些信息,我们对每个节点额外维护一下maxposdp, maxpos代表这个节点所维护的endpos集合的最大值,dp代表这个节点代表字符串的出现次数,具体维护细节请参考代码。

然后我们用反串在SAM上跑匹配,这个过程中记录匹配长度l与到了目前位置节点now,具体维护细节请参考代码,如果我们遇到一个节点,满足maxpos[now] - l < i, 那么我们就对它与它所有祖先节点进行遍历,如果满足maxpos[p] - len[p] < i && i <= maxpos[p],那么我们统计这个点对答案的贡献。

正确性我分成两部分,第一部分是为什么最大存在值会被统计,第二部分是为什么所有统计的答案都是存在值。

我们考虑这个最长回文串x,所在SAM的节点为p,考虑pmaxpos是否等于其右端点,如果不等于,那么说明有一个和它长得一模一样的回文串,我们就不考虑它了;如果等于,假设跑到正数第i个点,那么当i等于该回文串的右端点时,假设此时到了q点,那么显然有pq的在SAM上的祖先,此时显然满足l >= maxpos[p] - i,因为l一定比p的长度大,而p的长度即为maxpos[p] - i,又显然满足i<=st[p].maxpos,毕竟i是匹配串的左端点, 所以最大存在值会一定被统计。

为了保证所有统计的答案都是存在值,我们需要引入一个vis数组,每个点被统计完答案,我们就给它打个标记,每次只遍历没有遍历过的点,这样子每个点p只会在i匹配到maxpos时被遍历到。对于所有!vis[p] && l >= maxpos[p] - i && i<=st[p].maxpos的点, 它贡献的答案dp[p] * (maxpos[p] - i + 1)都显然一定是存在值。

复杂度就是建立SAM的复杂度,跑匹配的复杂度,统计每个点对答案贡献的复杂度,都是O(n)

这题空间比较小,#define int long longmle

#include<bits/stdc++.h>
using namespace std;
#define maxn 600005
#define ll long long
#define Fol(i, j, n) for(register int i = j ; i >= n ; --i) 
#define For(i, j, n) for(register int i = j ; i <= n ; ++i) 
ll ans;
int n, m, l, k, T, rt, cnt, now = 1, tot, last;
int a[maxn], dp[maxn], fa[maxn], rk[maxn], sa[maxn], len[maxn], num[maxn], vis[maxn], maxpos[maxn], go[maxn][26];
char s[maxn];

inline void insert(int c, int y){ 
    int np = ++tot, x = last; 
    last = np, dp[np] = 1, maxpos[np] = y, len[np] = len[x] + 1;
    for( ; !go[x][c] && x ; x = fa[x]) go[x][c] = np;
    if(!x) return fa[np] = rt, void();
    int nq = go[x][c];
    if(len[nq] == len[x] + 1) fa[np] = nq;
    else{
        int q = ++tot;
        memcpy(go[q], go[nq], sizeof(go[q]));
        len[q] = len[x] + 1, fa[q] = fa[nq], fa[nq] = fa[np] = q;
        for( ; go[x][c] == nq ; x = fa[x]) go[x][c] = q;
    }
}

int main(){
    scanf("%s", s + 1), n = strlen(s + 1), rt = last = ++tot;
    For(i, 1, n) insert(s[i] - 'a', i);
    For(i, 1, tot) num[len[i]]++;
    For(i, 1, n) num[i] += num[i - 1];
    For(i, 1, tot) a[--num[len[i]]] = i;
    Fol(i, tot, 1) dp[fa[a[i]]] += dp[a[i]], maxpos[fa[a[i]]] = max(maxpos[fa[a[i]]], maxpos[a[i]]);
    Fol(i, n, 1){
        for( ; now && !go[now][s[i] - 'a'] ; now = fa[now], l = len[now]);
        if(go[now][s[i] - 'a']) ++l, now = go[now][s[i] - 'a'];
        if(maxpos[now] - l < i){
            if(i <= maxpos[now]) ans = max(ans, 1ll * dp[now] * (maxpos[now] - i + 1)); 
            for(int p = fa[now] ; p && !vis[p] ;  p = fa[p]){
                vis[p] = 1;
                if(maxpos[p] - len[p] < i && i <= maxpos[p]) ans = max(ans, 1ll * dp[p] * (maxpos[p] - i + 1)); 
            }
        } 
    }
    printf("%lld", ans);
    return 0;
}