区间DP

· · 算法·理论

DP,动态规划,用于解决无后效性的问题,它把原问题视作若干个重叠的子问题的逐层递进,每个子问题的求解过程都会构成一个“阶段”,在完成一个阶段后,才会执行下一个阶段。

DP没有固定模版,绝对不可以在模版基础上思考转移方程。它很灵活,可以解决多数计数问题和最值问题,以及与其他算法结合。同样,转移方程也是很难想,有时候还要加入数学推导,比如容斥和二项式定理。

本博客来记录一下做过的DP的经验,是一个小总结。会随时更新,写的不对的地方请及时指出,谢谢~Thanks♪(・ω・)ノ

区间DP

只要题中有区间、合并、匹配,就要想到区间DP。它的大致思路为:枚举区间端点、长度,有些题要找区间的断点,考虑将断点左右两部分合并的代价,另外的题直接处理区间内的关系,一般和端点的改动有关

当数据范围较小时,才能考虑,因为时间复杂度至少是 O(n^2)。将问题分为两两合并的形式,可以通过子问题有效处理最终的总问题。

易错点

①枚举断点时不包含区间结尾,因为这样 [i,k] 就和 [i,j] 没有区别了,而且部分题中这会出错

②不是每个题都要枚举断点,从找断点出发思考是错误的

③字符串也可以用区间DP,不过一定注意下标差 1 的问题

例题

1.P1880

合并,确定为区间DP。

设区间为 [i,j],断点为 k,则将区间 [i,k][k+1,j] 合并所需的代价为 a[i]+a[i+1]+...+a[j]

fn_{i,j} 为合并出区间 [i,j]的最小代价,fx 为最大,那么 fn_{i,j}=min(fn_{i,j},fn_{i,k}+fn_{k+1,j}+a[i]+a[i+1]+...+a[j])fxmax

转移方程推出来了,这样就行了吗?显然不是。目前的时间复杂度为 O(n^4)=10^8,有点卡,所以在 O(n) 计算 a[i]+...+a[j] 时,用前缀和预处理,变成 a[j]-a[i-1]

这还不是最大的问题,题目中说在一个圆形操场的四周,也就是说这是一个环,那么下标为 1n 的石子也可以合并,而上述思路中并未考虑。想办法使 a 数组成为一个环,其实很简单,只需要把它本身复制一遍接到后面。注意不能只复制 a_1a_n 后面,因为合并时如果 a_1a_2 合并再和 a_n 合并,就无法相邻(举例)。

这个题初始化很简单,fn 全赋值为最大值,且保证 f_{i,i}0 就可以。

Code:
#include <bits/stdc++.h>
using namespace std;
int a[105],s[205],fn[205][205],fx[205][205];
int main()
{
    memset(fn,0x3f3f3f3f,sizeof(fn));
    int n;
    scanf("%d",&n);
    for(int i = 1;i <= n;i++)
    {
        scanf("%d",&a[i]);
    }
    for(int i = 1;i <= 2 * n;i++)
    {
        s[i] = s[i - 1] + a[i <= n ? i : i % n];
        fn[i][i] = 0;
    }
    for(int l = 1;l <= 2 * n;l++)
    {
        for(int i = 1;i <= 2 * n - l + 1;i++)
        {
            int j = i + l - 1;
            for(int k = i;k <= j - 1;k++)
            {
                fn[i][j] = min(fn[i][j],fn[i][k] + fn[k + 1][j] + s[j] - s[i - 1]);
                fx[i][j] = max(fx[i][j],fx[i][k] + fx[k + 1][j] + s[j] - s[i - 1]);
            }
        }
    }
    int mina = 0x3f3f3f3f,maxa = 0;
    for(int i = 1;i <= n;i++)
    {
        mina = min(mina,fn[i][i + n - 1]);
        maxa = max(maxa,fx[i][i + n - 1]);
    }
    printf("%d\n%d",mina,maxa);
    return 0;
}

2.括号匹配

给出一个的只有 ( ) [ ] 四种括号组成的字符串 s,求最多有多少个括号满足题目里所描述的完全匹配。数据满足 |s| \le 100

这是一个匹配问题,想到区间DP。

性质:两个所有括号都能被匹配的字符串,相连后括号一定也都能被匹配。

f_{i,j} 为区间 [i,j] 内完全匹配的括号数,推出转移方程:f_{i,j}=max(f_{i,j},f_{i,k}+f_{k+1,j})

这样肯定不够,因为没有任何可计算的数值。当区间 [i,j] 满足 s_i,s_j 是一对可以被匹配的括号时,f_{i,j}=f_{i+1,j-1}+2

此题不用初始化。

Code:
#include <bits/stdc++.h>
using namespace std;
int f[105][105];
char s[105];
int main()
{
    int n;
    scanf("%d",&n);
    getchar();
    scanf("%s",s + 1);
    for(int l = 1;l <= n;l++)
    {
        for(int i = 1;i <= n - l + 1;i++)
        {
            int j = i + l - 1;
            if((s[i] == '(' && s[j] == ')') || (s[i] == '[' && s[j] == ']'))
            {
                f[i][j] = f[i + 1][j - 1] + 2;
            }
            for(int k = i;k <= j - 1;k++)
            {
                f[i][j] = max(f[i][j],f[i][k] + f[k + 1][j]);
            }
        }
    }
    printf("%d",f[1][n]);
    return 0;
}

3.最多回文子串

给定一个字符串 s,求出其最多的可构成的回文字串(不要求连续),注意这里不同的回文字串只要求位置不同即可视为不同,比如 aaaaa 的最多回文子串数目是 31 个。数据保证 |s| \le 1000,答案要 mod 10007

这个题较难,首先要通过回文串的匹配性质想到区间DP。

性质:在一个回文串的两侧添加两个相同的字符,这还是一个回文串。

于是问题就转化成了找相同的字符进行匹配,需要区间DP。

f_{i,j}[i,j] 之间的回文串数量,由于回文串不要求连续,显然 f_{i,j}=f_{i+1,j} \cup f_{i,j-1}。此处要用到容斥原理f_{i+1,j} \cup f_{i,j-1}=f_{i+1,j}+f_{i,j-1}-f_{i+1,j} \cap f_{i,j-1}=f_{i+1,j}+f_{i,j-1}-f_{i+1,j-1},这就是一个转移方程。注意因为有减法,所以要先加上 mod 再计算,否则可能会有负数。

像上一题一样,当两端的字符相同时,需要更新,所以 f_{i,j}+=f_{i+1,j-1}+1。注意此处是 += 而非 =,因为区间 [i,j]s_i+[i+1,j-1]+s_j 中的回文串没有重叠的。

初始化是 f_{i,i}=1,区间端点按照 s_i 是否等于 s_j 分类,两个转移方程都已推出,时间复杂度为 O(n^2)。此题无需枚举断点,如果从这里出发思考就会超时,并且不一定能做出来。

Code:
#include <bits/stdc++.h>
using namespace std;
int f[1005][1005],mod = 10007;
char s[1005];
int main()
{
    int n;
    scanf("%d",&n);
    getchar();
    scanf("%s",s + 1);
    for(int i = 1;i <= n;i++)
    {
        f[i][i] = 1;
    } 
    for(int l = 1;l <= n;l++)
    {
        for(int i = 1;i <= n - l + 1;i++)
        {
            int j = i + l - 1;
            f[i][j] = (f[i + 1][j] + f[i][j - 1] - f[i + 1][j - 1] + mod) % mod;
            if(s[i] == s[j])
            {
                f[i][j] = (f[i][j] + f[i + 1][j - 1] + 1) % mod;
            }
        }
    }
    printf("%d",f[1][n]);
    return 0;
}

4.P3147

这个题很难想到是区间DP,并且设的DP状态也不是区间 [i,j],不过由于题目中有合并操作,所以尝试两两合并。

再看DP数组 $f$,显然是二维,那么大小一定是 $n*logn$ 左右。$logn=58$,题目中的正整数最大也就 $40$,那么两维应该分别为**位置及合并出的数值**。 通过这两点,基本确定 $f$ 的状态 $f_{i,j}$ 为左端点位置为 $j$ 且能合并出数值 $i$ 的第一个右端点。**但最终的状态是右端点位置 $+1$**,解释一下为什么。 已知两个 $i-1$ 能合并出 $i$,那么在合并出 $i-1$ 的右端点的**右侧**继续合并出另一个 $i-1$,就可以合并出 $i$ 了。如果不重点注意右侧,那么下一个位置将还是原右端点。由于存储的是第一个右端点,所以一定是 $j$ 右侧的第一个合并出 $i$ 的位置。 $f[i][j] = f[i - 1][f[i - 1][j]]

最难的部分已经完成了,现在思考如何初始化和记录答案。初始化为从位置 i 右侧第一个合并出 a(输入的正整数)的位置是 i+1,而答案为最大能合并出来的数值,也就是要有 f_{maxa,x} 不为 0,那么在状态转移的循环里判断一下 f_{i,j} 的值就可以记录。

代码虽然很短,但是思路很难想,倍增思想。此题不是严格的区间DP,它其实就是从 [j+1,n] 里找 i

Code:
#include <bits/stdc++.h>
using namespace std;
int f[60][300005]; 
int main()
{
    int n;
    scanf("%d",&n);
    for(int i = 1;i <= n;i++)
    {
        int a;
        scanf("%d",&a);
        f[a][i] = i + 1;
    }
    int maxa = 0;
    for(int i = 1;i <= 58;i++)
    {
        for(int j = 1;j <= n;j++)
        {
            if(f[i][j] == 0)
            {
                f[i][j] = f[i - 1][f[i - 1][j]];
                //如果记录的是右端点,这一行f[i-1][f[i-1][j]]就还是f[i-1][j],显然错误
            }
            if(f[i][j] != 0)
            {
                maxa = i;
            }
        }
    }
    printf("%d",maxa);
    return 0;
}

5.P4302

f_{i,j}[i,j] 的最小长度,初始化 f_{i,i}=1,显然 f_{i,j}=min(f_{i,j},f_{i,k}+f_{k+1,j})

一段字符串除了合并,还可以压缩自身,所以考虑压缩。依旧枚举 k,并考虑将 [i,k] 这一段作为循环的字符串折叠。那么需要满足的条件是循环字符串的长度是 l 的因数,且每一段都相同(能作为循环的部分)

当满足条件后,需要更新 f_{i,j} 的值,根据题目格式,折叠成如下形式:

l/(k-i+1)(f[i]~f[k])

也就是在 f_{i,k} 的基础上 +2(两个括号)再加上 l/(k-i+1) 这个数字的长度。需要预处理 1-100 的数字位数,便于调用。

最后,根据此题我们知道了,涉及循环的区间的问题也可以用区间DP。

补充知识:当一个函数 fun 传入的参数为 fun(a+i)a 是一个数组),函数内接收到的数组(假设为 b[])就是数组 a 从下标 i 开始直到结束的数值。所以下面代码中的 fun(i,l,x) 可以写成 fun(s+i,l,x),就能免掉函数内各种 +l

Code:
#include <bits/stdc++.h>
using namespace std;
char s[105];
int n,a[105],f[105][105];
int fun(int l,int m,int x)
{
    for(int i = l + x;i <= l + m - 1;i++)
    {
        if(s[i] != s[(i - l) % x + l])
        {
            return 0;
        }
    }
    return 1;
}
int main()
{
    memset(f,0x3f3f3f,sizeof(f));
    scanf("%s",s);
    n = strlen(s);
    for(int i = 0;i <= n - 1;i++)
    {
        f[i][i] = 1;
    }
    for(int i = 1;i <= 9;i++)
    {
        a[i] = 1;
    }
    for(int i = 10;i <= 99;i++)
    {
        a[i] = 2;
    }
    a[100] = 3;
    for(int l = 2;l <= n;l++)
    {
        for(int i = 0;i <= n - l;i++)
        {
            int j = i + l - 1; 
            for(int k = i;k <= j - 1;k++)
            {
                f[i][j] = min(f[i][j],f[i][k] + f[k + 1][j]);
            }
            for(int k = i;k <= j - 1;k++)
            {
                int x = k - i + 1;
                if(l % x != 0)
                {
                    continue;
                }
                if(fun(i,l,x) == 1)
                {
                    f[i][j] = min(f[i][j],f[i][k] + a[l / x] + 2);
                }
            }
        }
    }
    printf("%d",f[0][n - 1]);
    return 0;
}

6. P2470

此题和P4302很像,但它可以嵌套循环,所以更为复杂。(是个紫题)

依旧使用区间DP,设 f_{i,j,k}(k=0,1) 是区间 [i,j] 是否包含M的最小压缩长度。

思路:每次先根据断点通过相加取最小值,然后将区间 [i,j] 折叠,如果满足折叠的两部分相同,则有式子 f_{i,j,1}=min(f_{i,j,1},min(f_{i,x-1,1}-1,f_{i,x-1,0})+2) 更新 f_{i,j,1}xi+l/2,最后的结果为 min(f_{1,n,0},f_{1,n,1})

看起来很对,实际上呢?

Hack:(样例1)

aaaaaaa

正确答案应该是压缩成aaaRa,也就是 5,但程序输出了 6。所以上述非正解,到这里我就一直 70pts 不会了,太菜所以看了题解。

发现自己之前是压缩错误,没有考虑题目中“R重复从上一个M开始的解压结果”,直接又可以压缩的就加入R,实际上R前面已经被压缩过了。

Code:
#include <bits/stdc++.h>
using namespace std;
char s[105];
int n,f[105][105][3];
int fun(int x,int y)
{
    if((y - x + 1) % 2 != 0)
    {
        return 0;
    }
    int m = (y - x + 1) / 2;
    for(int i = x;i <= x + m - 1;i++)
    {
        if(s[i] != s[i + m])
        {
            return 0;
        }
    } 
    return 1;
}
int main()
{
    scanf("%s",s + 1);
    n = strlen(s + 1);
    for(int i = 1;i <= n;i++)
    {
        f[i][i][1] = 1;
        f[i][i][2] = 2;
    }
    for(int l = 2;l <= n;l++)
    {
        for(int i = 1;i <= n - l + 1;i++)
        {
            int j = i + l - 1;
            f[i][j][1] = 1e9;
            f[i][j][2] = 1e9;
            if(fun(i,j) == 1)
            {
                int x = (i + j) / 2;
                f[i][j][1] = min(f[i][j][1],f[i][x][1] + 1); 
            } 
            for(int k = i;k <= j - 1;k++)
            {
                f[i][j][1] = min(f[i][j][1],f[i][k][1] + j - k);
                f[i][j][2] = min(f[i][j][2],min(f[i][k][1],f[i][k][2]) + min(f[k + 1][j][1],f[k + 1][j][2]) + 1); 
            }
        }
    }
    printf("%d",min(f[1][n][1],f[1][n][2]));
    return 0;
}