题解:P12917 [POI 2021/2022 R3] 小矮人派对 2 / Impreza krasnali 2

· · 题解

竞选全场最难写的题解,我应该是第一个用这个方法的人。

首先我们可以非常轻松的写出一个暴力代码。成功获得 12pts。

然后我们会看到附加样例里面写的:

该样例满足 n=100000,h_i=i,答案为 F_{n+1}\bmod 10^9+7,其中 F_i 是第 i 个斐波那契数。

这不禁让我们思考:如果每个 i 都只出现一次,但是又不保证 h_i=i 呢?

然后我用暴力代码跑了几组,发现这个结论依然成立!

于是我们就得到了一个没经过证明的结论:在一段长度为 n 的区间中,如果每个数字都只出现过一次,那么这一段的方案数就是 F_n,其中 F_i 是斐波那契数列的第 n 项。现在我们来证明一下这个结论:

证明:设 f(n) 表示一段长度为 n 的区间,且这段区间中的每一个数都只出现一次时的可行方案数,那么我们只需要证明这个 f 满足斐波那契数列的性质就行。

显然当 n=1 时,f(1)=1,当 n=2 时,f(2)=2

则当 n\ge 3 时,考虑最后一个数,此时它可以选择描述它自己,那么方案数就是 1\times f(n-1),同样也可以选择和前一个互相描述对方,那么方案数就是 1\times f(n-2),所以 f(n)=f(n-1)+f(n-2),这显然就是斐波那契数列的定义啊。

没看明白证明的看一下图:

如果把上面看做题目给的数列,下面看做我们构造出来的数列,那么最后一个位置要么直接描述自己,要么和前一个位置互相描述对方,两种方案的方案数分别是 f(n-1)f(n-2)。所以可以得到 f(n)=f(n-1)+f(n-2)

然后我们考虑怎么用这个结论来计算答案。因为这个结论针对的是只出现过一次的数,所以我们考虑所有出现过多次的数。

我们很容易看出:如果一个数出现了三次以上,那么显然没有一种构造方案成立;如果一个数出现了三次,并且三次出现的位置都是相邻的,那么我们必须把这个数放在中间的那个位置;如果一个数出现了两次,那么考虑这两次的位置,如果位置相邻,则两边都可以放;如果两个位置中间隔了一个空位,那么就只能放在中间。具体见图:

上面的方框代表的是输入进来的多个相同的数的位置,下面则是我们构造的数列。那么显然只有这四种情况。

而在两个位置相邻的出现了不止一次的数字中间,所有的数字必然都只出现了一次,所以我们可以用上面的结论算出中间的贡献。

但是现在有个问题:在我们上面提出的四种情况之中,第二种和第三种显然合在一起是一种,但是这一种最终可以放的位置有两个,而这两种放的方法会导致中间可用的格子数不同,算出来的方案数也就不同。因此这里我们使用 DP:设 f_{i,0/1/2} 表示考虑到第 i 位,且这一位的数字出现了多次,不会对左右两边造成影响或让左边多出来一个格子、右边少一个格子或让右边多出来一个格子,左边少一个格子的方案数是多少。

然后这个转移很容易想到。注意第四种情况中间还有一个数字没有放,所以也会对左边或右边的长度造成影响。

现在最后一个问题:转移时的贡献到底是多少?如果在转移的时候,中间的数字个数和没放数字的位置个数是相等的,那么就直接一一对应,贡献就是 F_{len},其中 len 是中间的数字个数,F 是预处理的斐波那契数列。如果能放数字的位置比中间的数字个数多 1,那么我们考虑其中有 i 个数字是被选定直接构造的,剩下的数字都要平移一格在构造到数列上,那么此时的贡献是 \sum_{i=0}^{len}F_i,这个可以用前缀和预处理。如果如果能放数字的位置比中间的数字个数多 2,那么我们考虑中间有 i 个 数字是被选定直接构造的,左边的数字全往左平移一格,右边的数字全往右平移一格,那么此时的贡献就是 \sum_{i=0}^{len}F_i\times(len-i+1),然后拆式子得到 (len+1)\times\sum_{i=0}^{len}F_i-\sum_{i=0}^{len}F_i\times i,再分别用前缀和维护一下,然后就可以 O(1) 转移了!

总时间复杂度:O(n)

注:这种解法的难点主要在于处理中间到底有多少个数,因为不同的情况之间搭配会有不同的结果,所以需要大量分类讨论。

代码:

#include<bits/stdc++.h>
#define int long long
#define code using
#define by namespace
#define plh std
code by plh;
namespace fastio
{
    inline int read()
    {
        int z=0,f=1;
        char c=getchar();
        if(c==EOF)
        {
            exit(0);
        }
        while(c<'0'||c>'9')
        {
            if(c==EOF)
            {
                exit(0);
            }
            if(c=='-')
            {
                f=-1;
            }
            c=getchar();
        }
        while(c>='0'&&c<='9')
        {
            z=z*10+c-'0';
            c=getchar();
        }
        return z*f;
    }
    inline void write(int x)
    {
        if(x<0)
        {
            putchar('-');
            x=-x;
        }
        static int top=0,stk[106];
        while(x)
        {
            stk[++top]=x%10;
            x/=10;
        }
        if(!top)
        {
            stk[++top]=0;
        }
        while(top)
        {
            putchar(char(stk[top--]+'0'));
        }
    }
    inline void write(string s)
    {
        for(auto i:s)
        {
            putchar(i);
        }
    }
}
using namespace fastio;
const int mod=1e9+7;
int n,a[100006],dp[100006],s1[100006],s2[100006],f[100006][3];
vector<int>v[100006];
void solve(int la,int x)
{
    if(la==0)
    {
        if(a[x]==1)
        {
            f[x][0]=s1[x-la-2];
        }
        else if(a[x]==2)
        {
            f[x][1]=s1[x-la-1];
            f[x][2]=dp[x-la-1];
        }
        else
        {
            f[x][1]=s1[x-la-2];
            f[x][2]=dp[x-la-2];
        }
        return;
    }
    if(a[x]==1)
    {
        if(a[la]==1)
        {
            f[x][0]=(f[x][0]+f[la][0]*((s1[x-la-3]*(x-la-2)%mod-s2[x-la-3]+mod)%mod)%mod)%mod;
        }
        else
        {
            f[x][0]=(f[x][0]+f[la][2]*((s1[x-la-3]*(x-la-2)%mod-s2[x-la-3]+mod)%mod)%mod+f[la][1]*s1[x-la-3]%mod)%mod;
        }
    }
    else if(a[x]==2)
    {
        if(a[la]==1)
        {
            f[x][1]=(f[x][1]+f[la][0]*((s1[x-la-2]*(x-la-1)%mod-s2[x-la-2]+mod)%mod)%mod)%mod;
            f[x][2]=(f[x][2]+f[la][0]*s1[x-la-2])%mod;
        }
        else
        {
            f[x][1]=(f[x][1]+f[la][2]*((s1[x-la-2]*(x-la-1)%mod-s2[x-la-2]+mod)%mod)%mod+f[la][1]*s1[x-la-2]%mod)%mod;
            f[x][2]=(f[x][2]+f[la][2]*s1[x-la-2]%mod+f[la][1]*dp[x-la-2]%mod)%mod;
        }
    }
    else
    {
        if(a[la]==1)
        {
            f[x][1]=(f[x][1]+f[la][0]*((s1[x-la-3]*(x-la-2)%mod-s2[x-la-3]+mod)%mod)%mod)%mod;
            f[x][2]=(f[x][2]+f[la][0]*s1[x-la-3])%mod;
        }
        else if(a[la]==2)
        {
            f[x][1]=(f[x][1]+f[la][2]*((s1[x-la-3]*(x-la-2)%mod-s2[x-la-3]+mod)%mod)%mod+f[la][1]*s1[x-la-3]%mod)%mod;
            f[x][2]=(f[x][2]+f[la][2]*s1[x-la-3]%mod+f[la][1]*dp[x-la-3]%mod)%mod;
        }
        else
        {
            if(x-la==1)//这里特判一下:如果第四种情况交在一起,那么只有一种构造方法
            {
                f[x][2]=(f[x][2]+f[la][1])%mod;
            }
            else
            {
                f[x][1]=(f[x][1]+f[la][2]*((s1[x-la-3]*(x-la-2)%mod-s2[x-la-3]+mod)%mod)%mod+f[la][1]*s1[x-la-3]%mod)%mod;
                f[x][2]=(f[x][2]+f[la][2]*s1[x-la-3]%mod+f[la][1]*dp[x-la-3]%mod)%mod;
            }
        }
    }
}
signed main()
{
    n=read();
    for(int i=1,x;i<=n;i++)
    {
        x=read();
        v[x].push_back(i);
    }
    bool fl=true;
    for(int i=1;i<=n;i++)
    {
        if(v[i].size()>3)
        {
            fl=false;
            break;
        }
        else if(v[i].size()==3)
        {
            if(v[i][2]-v[i][0]>2)
            {
                fl=false;
                break;
            }
            else
            {
                a[v[i][0]+1]=1;
            }
        }
        else if(v[i].size()==2)
        {
            if(v[i][1]-v[i][0]>2)
            {
                fl=false;
                break;
            }
            else
            {
                if(v[i][1]-v[i][0]==2)
                {
                    a[v[i][0]+1]=3;
                }
                else
                {
                    a[v[i][0]]=2;
                }
            }
        }
    }
    if(!fl)
    {
        write(0);
        return 0;
    }
    dp[0]=dp[1]=1;
    for(int i=2;i<=n;i++)
    {
        dp[i]=(dp[i-1]+dp[i-2])%mod;
    }
    s1[0]=1;
    for(int i=1;i<=n;i++)
    {
        s1[i]=(s1[i-1]+dp[i])%mod;
        s2[i]=(s2[i-1]+dp[i]*i%mod)%mod;
    }
    int la=0;
    for(int i=1;i<=n;i++)
    {
        if(a[i]==0)
        {
            continue;
        }
        solve(la,i);
        la=i;
    }
    if(a[n]==0)
    {
        if(la==0)
        {
            f[n][0]=dp[n];
        }
        else
        {
            if(a[la]==1)
            {
                f[n][0]=(f[n][0]+f[la][0]*s1[n-la-1])%mod;
            }
            else
            {
                f[n][0]=(f[n][0]+f[la][2]*s1[n-la-1]%mod+f[la][1]*dp[n-la-1]%mod)%mod;
            }
        }
    }
    int ans=f[n][0];
    for(int i=1,cnt=1;i<=n;i++)
    {
        if(v[i].empty())
        {
            ans=(ans*cnt)%mod;
            cnt++;
        }
    }
    write(ans);
    return 0;
}

感觉讲的好像不是很清楚,如果你有任何疑问,可以发在讨论区里,我会尽快回答你的问题。