一道后缀数组的题目。
初学的蒟蒻也只能抄抄题解这样子。
最暴力的想法:在两个串中枚举极长子串
可以只考虑两原串某后缀的所有前缀
从而不重不漏找到所有子串
考虑到后缀的前缀,自然想到后缀数组
把两个串接在一起,中间隔一个其他字符(除小写字母外)
求出任意两后缀的LCP的值
可以用het数组和ST表实现,但依然不足以AC
进一步考虑:
由于het表示的是排名相邻的两个后缀LCP的长度,
所以任意两个后缀的LCP长度为按字典序排序后它们中间最小的het
也就是说排序后,一个后缀越往后数LCP的长度越小
这样,我们就可以用单调栈维护这个最小值
单调栈中有两个值:一个是het值,一个是位置i
i在这里充当一个系数(因为弹出的元素实际上还会放回去)
分A串的子串在前和B串的子串在前两种情况进行讨论
两种情况答案相加即可
```cpp
#include<cstdio>
#include<cstring>
#include<algorithm>
#define reg register
using namespace std;
typedef long long ll;
const int N=4e5+5;
char c[N];
int n,l1,l2,top,sum[N];
pair<int,ll>stack[N];
struct HOU
{
int n,m,a[N],top[N],rank[N],sa[N],tax[N],het[N];
inline void qsort()
{
memset(tax,0,sizeof(tax));
for (reg int i=1;i<=n;i++) ++tax[rank[i]];
for (reg int i=1;i<=m;i++) tax[i]+=tax[i-1];
for (reg int i=n;i>=1;i--)
sa[tax[rank[top[i]]]--]=top[i];
}
inline void getSA()
{
for (reg int i=1;i<=n;i++) rank[i]=a[i],top[i]=i;
m=127; qsort();
for (reg int w=1,p=0;p<n;m=p,w<<=1)
{
p=0;
for (reg int i=1;i<=w;i++) top[++p]=n-w+i;
for (reg int i=1;i<=n;i++)
if (sa[i]>w) top[++p]=sa[i]-w;
qsort(); swap(rank,top);
rank[sa[1]]=p=1;
for (reg int i=2;i<=n;i++)
if (top[sa[i-1]]==top[sa[i]]&&top[sa[i-1]+w]==top[sa[i]+w])
rank[sa[i]]=p;
else rank[sa[i]]=++p;
}
int k=0;
for (reg int i=1;i<=n;i++)
{
k=(k?k-1:0);
while (c[i+k]==c[sa[rank[i]-1]+k]) ++k;
het[rank[i]]=k;
}
}
}R;
inline ll getans()//两个后缀的LCP(最长公共前缀长度)为按照字典序排序后它们之间最小的het
{
ll ans=0;
stack[0]=make_pair(1,0);
for (reg int i=1;i<=R.n;i++)
sum[i]=sum[i-1]+(R.sa[i]<=l1);
for (reg int i=1;i<=R.n;i++)
{
while (top&&R.het[stack[top].first]>R.het[i]) --top;
stack[++top]=make_pair(i,1ll*(sum[i-1]-sum[stack[top-1].first-1])*R.het[i]+stack[top-1].second);
if (R.sa[i]>l1+1) ans+=stack[top].second;
}
top=0;
for (reg int i=1;i<=R.n;i++)
sum[i]=sum[i-1]+(R.sa[i]>l1+1);
for (reg int i=1;i<=R.n;i++)
{
while (top&&R.het[stack[top].first]>R.het[i]) --top;
stack[++top]=make_pair(i,1ll*(sum[i-1]-sum[stack[top-1].first-1])*R.het[i]+stack[top-1].second);
if (R.sa[i]<=l1) ans+=stack[top].second;
}
return ans;
}
int main()
{
scanf("%s",c+1); l1=strlen(c+1);//第一个串
scanf("%s",c+l1+2); c[l1+1]='z'+1;//把第二个串接在第一个串后面,中间隔一个其他字符
R.n=strlen(c+1);
for (reg int i=1;i<=R.n;i++) R.a[i]=c[i]-'a'+1;
R.getSA(); printf("%lld\n",getans());
return 0;
}
```