P3435题解

· · 题解

题解里面怎么这么多kmp。。。。我这个彩币根本不会。。。
首先看到n=1e6,考虑nlogn算法。
发现每个前缀都有一定的“适用范围”。
定义一个前缀的适用范围(l,r)为此前缀为前l个字母至前r个字母的周期。我们只需要求出此,然后按前缀长度从小到大覆盖区间,最后求出1,n的区间和就是答案。
求适用范围可以用字符串hash求,区间推平可以用线段树。 具体见代码。

#include<bits/stdc++.h>
using namespace std;
#define int long long
#define gets(S) fgets(S,sizeof(S),stdin)
#define mid ((l+r)/2)
int n,ans,h[4000100],mod=131,len,l,r,pw[4000100];
struct xy
{
    int x,y;
}p[4000100];
char a[4000100];
inline bool check(int w,int k)
{//1~w:h[i] w+1~k:h[w]*pw[i-w]+h[i-w]
    int t;
    if(k<=w)t=h[k];
    else t=h[w]*pw[k-w]+h[k-w];
    if(h[k]==t)return 1;
    else return 0;
}
int c[4000100],s[4000100];
void pushdown(int k,int l,int r)
{
    if(c[k]!=0)
    {
        c[k*2+1]=c[k*2]=c[k];
        s[k*2]=(mid-l+1)*c[k];
        s[k*2+1]=(r-mid)*c[k];
        c[k]=0;
    }
}
void pushup(int k)
{
    s[k]=s[k*2]+s[k*2+1];
}
void clean(int k,int l,int r,int ll,int rr,int cl)
{
    if(r<ll||rr<l)return;
    if(r<l)return;
    if(ll<=l&&r<=rr)
    {
        s[k]=cl*(r-l+1);
        c[k]=cl;
        return;
    }
    pushdown(k,l,r);
    clean(k*2,l,mid,ll,rr,cl);
    clean(k*2+1,mid+1,r,ll,rr,cl);
    pushup(k);
}
signed main()
{
    scanf("%lld",&n);
    pw[0]=1;
    for(int i=1;i<=n*2+1;++i)pw[i]=pw[i-1]*mod;
    //for(int i=1;i<=n;++i)printf("%lld ",pw[i]);
    //puts("");
    gets(a);
    gets(a);
    for(int i=n-1;i>=0;--i)a[i+1]=a[i];
    for(int i=1;i<=n;++i)h[i]=h[i-1]*mod+a[i]-'a'+1;
    //for(int i=1;i<=n;++i)printf("%lld ",h[i]);
//  puts("");
    for(int i=1;i<=n;++i)
    {
        l=i;
        r=min(i+i,n);
        while(l+1<r)
        {
        //  printf("%lld %lld %lld %lld\n",l,r,mid,i);
            if(check(i,mid))l=mid;
            else r=mid-1;
        }
        if(r==l+1)
        {
            if(check(i,r))l=r;
        }
        if(l>i)
        {
            p[++len]={i,l};
        }
    }
    p[len+1]={INT_MAX,0};
    for(int i=1;i<=len;++i)
    {
    //  printf("%lld %lld\n",p[i].x,p[i].y);
        clean(1,1,n,p[i].x+1,p[i].y,p[i].x);
        /*
        for(int i=1;i<=40;++i)
        {
            printf("%lld ",s[i]);
        }
        puts("");
        for(int i=1;i<=40;++i)
        {
            printf("%lld ",c[i]);
        }
        puts("");*/
    }
    printf("%lld\n",s[1]);
    return 0;
}