数位DP

· · 个人记录

数位DP

问题类型是求解小于等于某个给定n的、数位符合一定条件的的正整数的个数。注意如果题目要求是求解某个正整数区间上的、符合一定条件的数的个数,则可以轻易化归到上述问题。

我们试图从最高一位开始,逐位取数。我们考虑在取每一位数的时候,能够有多少选择。注意我们需要取的数需要满足两个条件:

  1. 小于等于某个给定的n
  2. 数位符合题目给定的条件。

对于第一个条件,我们需要注意到对于任何小于n的正整数,它从高到低一定存在某一数位,使得这一数位小于n对应的数位,且其之前的所有数位都和n上的相等。我们不妨称这一数位为破冰数位(一个临时杜撰的叫法)。换言之,在破冰数位之前,所有的数位都尽可能地取了最大的值。例如,假如n8763,对于8742来说,十位上的4就是一个破冰数位。一个容易的结论是,如果我们列举的数不存在破冰数位,那么它应该和n相等。

问题在于,我们在取某一位数的时候,此前是否已经出现了破冰数位。很明显,如果此前尚未出现破冰数位,那么在选取这一数位的时候,我们的上界是受限的,即不能超过n的同一数位;而如果此前已经出现了破冰数位,那么我们的上界就是简单的9,不存在其他限制。

由此可见,破冰数位的是否已经出现会影响到我们的对数位的选择,因此需要将这一状态加入考虑。

对于第二个条件,这是每道题目最需要思考的部分。例如,windy数的定义是“不含前导零且相邻两个数字之差至少为2的正整数被称为 windy 数”。由于这里的限制条件仅包含“相邻数字”,所以在每一位影响决策的只有上一位的数字的选取,这一状态需要被考虑才能进行决策。而在萌数一题中,可以推出“非萌数”的定义定价于“所有数位均与前一位、前两位都不同”的正整数,因此在每一位影响决策的有上两位的数字的选取。

注意当我们将上一位数字的选取作为状态参数考虑的时候,需要区分数字内部的0和前导0,因为后者并不是我们真正进行选择的结果,往往并不参与到限制条件的运算中。例如,在windy数一题中,不能简单地以“上一位选取是0”开始递归,因为这样会误将所有的12开头的数字均判断为非windy数。正确的做法是用一个特殊的数字来表示前导零,如10,并在后面的程序中对其进行特殊判断;更好的做法是保证这一选取不会和题目的限制冲突,例如在windy数中选取11表示前导0,即可(巧妙地)绕过对windy数身份的判断。而当我们已知上一位是前导0的情况下,假如在这一位又选取了0,我们需要将它标记为前导零再进行递归。

结束了上述论证后,我们不难发现, 对于每个固定的n,在确定了所选数位序号、破冰数位的是否存在、前序数位对所选数位会产生影响的状态参数后,所得到的结果总是一样的。例如,在windy数一题中,假如n8763,在选择724\_和选择694\_时,由于都在选择个位、十位都是4、之前均已经出现了破冰数位,其后的选择方案数无疑是一样的。这就给了我们使用记忆化搜索的可能,或者也可将整个问题变形为递推,问题到这就是老生常谈了。

在实现中需要注意的细节:

  1. 由于所有的操作都基于数位,常见做法是将给定的n逐位存入一个vector中再进行运算。甚至有时题目给的n就很大,读入的时候直接用stringvector存入。但需要注意n的数位在容器中的顺序,用两种方法得到的是相反的。
  2. 递归的初始值可以选用当选取位数越界时返回1的方案。因为当选取位数越界时,表示函数在边界位被调用,而在边界位的情形显然表示所有的数位已经被枚举完了,故而是一种方案。
  3. 记忆数组会随n的变化变化,因此在计算新的n的值的时候需要初始化记忆数组。

下面给出windy数的实现:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

ll f[15][15][2];

ll dfs(int x, int st, int op, vector<int> &dim) { 
    // op = 0 means <, op = 1 means ==
    // st = 10 when it is leading zero
    if (x == -1) return 1; 

    ll &ret = f[x][st][op];
    if (ret != -1) return ret;
    ret = 0;

    int m = op ? dim[x] : 9; 
    for(int i = 0; i <= m; i++) if (abs(st - i) >= 2 || st == 10){
        ret += dfs(x - 1, (st == 10 && i == 0) ? 10 : i, op & (i == m), dim);
    }
    return ret;
}

ll solve(int x) {
    vector<int> dim;
    while(x) {
        dim.push_back(x % 10);
        x /= 10;
    }
    memset(f, -1, sizeof(f));
    return dfs(dim.size() - 1, 10, 1, dim); 
}

int main() {
#ifdef D
    freopen("2567.in", "r", stdin);
#endif
    int l, r; cin >> l >> r;
    ll ret = solve(r) - solve(l - 1);
    cout << ret << endl;
    return 0;
}

下面给出萌数的实现. 本题的数据范围较大, 因此L-1不能直接运算, 需要特殊判断左侧边界的性质.

#include <bits/stdc++.h>
using namespace std;
#define rep(i,from,to) for(register int i=from;i<=to;++i)
#define For(i,to) for(register int i=0;i<(int)to;++i)
typedef long long ll;

const long long M = 1000000007;

vector<int> l, r;
ll f[1024][11][11][2];

inline void read(vector<int> &res){
    res.clear();
    char c = getchar();
    while(c < '0' || c > '9') c = getchar();
    while(c >= '0' && c <= '9') {
        res.push_back(c - '0');
        c = getchar();
    }
}

ll dp(int x, int st1, int st2, int op, vector<int> &dim) {
    if(x == (int)dim.size()) return 1;
    ll &ret = f[x][st1][st2][op];
    if (ret != -1) return ret;
    ret = 0;

    int m = op ? dim[x] : 9;

    for(int i = 0; i <= m; i++) if (st1 != i && st2 != i){
        if (st2 == 10 && i == 0) {
            ret += dp(x + 1, 10, 10, op & (i == m), dim);
        } else {
            ret += dp(x + 1, st2, i, op & (i == m), dim);
        }
        ret %= M;
    }

    return ret;
}

ll solve(vector<int> &dim) {
    memset(f, -1, sizeof(f));
    ll res = dp(0, 10, 10, 1, dim); 
    // res is count of all the un-meng numbers k s.t. 0 <= k <= n, n represented by dim
    ll s = 0;
    For(i, dim.size()) {
        int u = dim[i];
        s = (s * 10) % M;
        s = (s + u) % M;
    }
    s = s - res + 1;
    s %= M; s += M; s %= M;
    return s;
}

bool judge(vector<int> &dim) {
    for(int i = 0, s = dim.size(); i < s - 2; ++i) {
        if (dim[i] == dim[i + 2] || dim[i] == dim[i + 1] || dim[i + 1] == dim[i + 2]) return true;
    }
    return false;
}

int main() {
#ifdef D
    freopen("3413.in", "r", stdin);
#endif
    read(l); read(r);
    ll ans = solve(r) - solve(l) + judge(l);
    ans %= M; ans += M; ans %= M;
    cout << ans << "\n";
    return 0;
}