题解:P15301 [ROI 2012 Day 2] army 汗国军队

· · 题解

第一次见 DP 套 DP。

题意

求出有多少排列满足给定序列是它的一个 LIS。

Solution

看到数据范围,可以猜到复杂度里有个 2^n3^n。先满足条件 LIS 长度不超过 k。考虑从小往大在序列中插入数,先假设我们已知每个数要插在那个位置,并维护过程中的 LIS。

LIS 可以用 DP 维护,这里我们维护数组 g_i,表示前缀 [1,i] 中,LIS 的长度。当我们在位置 x 后插入数时,数组 g 的变化如下(注意由于是插入,数组下标也会变化):

g_i\gets \begin{cases} g_i &i\le x\\ g_{i-1}+1 & i=x+1\\ \max(g_{i-1},g_x) & i>x+1\end{cases}

容易发现任意时刻,数组 g 的差分数组为 01 数组,因此可以用二进制数来刻画数组 g。考虑将上述插入带来的修改刻画到 01 串上:

原串拆成两部分:[1,x][x+1,n],在两部分之间插入一个 1,并删掉 [x+1,n] 的第一个 1。以上操作可以用位运算 O(1) 实现。

这样我们就能用 DP 统计了,定义新的 DP f_{i,mask} 表示填了 i 个数,g 的差分数组为 mask 的方案数。由于还需满足第二个条件,所以再加一维 j,表示最后一个特殊数的位置在 j,填的过程中保证特殊数填的位置递增即可。

注意 f_{i,j,mask} 的状态数其实为 O(n2^n),因为 n(2^0+2^1+\cdots +2^n)=n2^{n+1}。转移需枚举插入的位置,因此总时间复杂度为 O(n^22^n)

代码

#include <iostream>
#include <cstdio>
using namespace std;
int read()
{
    char c=getchar();
    int f=1,x=0;
    while(c<'0'||c>'9')
    {
        if(c=='-') f=-1;
        c=getchar();
    }
    while(c>='0'&&c<='9')
    {
        x=(x<<1)+(x<<3)+(c^'0');
        c=getchar();
    }
    return x*f;
}
void print(int x)
{
    if(x<0)
    {
        putchar('-');
        x=-x;
    }
    if(x>9) print(x/10);
    putchar(x%10+'0');
}
const int N=20;
int n,m,ans,f[N][N][1<<15];
bool a[N];
int main()
{
    n=read();
    m=read();
    for(int i=1;i<=m;i++) a[read()]=true;
    f[0][0][0]=1;
    for(int i=1;i<=n;i++)
    {
        for(int j=0;j<i;j++)
        {
            for(int k=0;k<(1<<(i-1));k++)
            {
                int val=f[i-1][j][k];
                if(!val) continue;
                for(int x=0;x<i;x++)
                {
                    if(a[i]&&x<j) continue;
                    int r=(k>>x),l=(k^(r<<x));
                    r^=(r&(-r));
                    r=(r<<1|1);
                    int y=j,s=(l|(r<<x));
                    if(a[i]) y=x+1;
                    else y+=(x<j);
                    f[i][y][s]+=val;
                }
            }
        }
    }
    for(int i=0;i<=n;i++)
    {
        for(int j=0;j<(1<<n);j++)
        {
            int x=0;
            for(int k=0;k<n;k++)
                if((j&(1<<k))) x++;
            if(x<=m) ans+=f[n][i][j];
        }
    }
    print(ans);
    return 0;
}