数位 DP 学习笔记

· · 算法·理论

前言

本文主要参考这篇文章,文章作者的详细讲解使得作者对数位 DP 有了初步的了解。

当然,本文的讲解方式也会与这篇文章类似。

简介

数位 DP 解决在区间 [A,B] 中,满足特定条件的数的个数,A,B 总是很大,往往会达到 10^9 甚至 10^{18} 的级别。

数位 DP 往往对于数的每一位进行分析与处理,我们一般用记忆化搜索而非递推进行求解。

处理

区间 [l,r] 的答案可以视作区间 [0,r] 的答案减去区间 [0,l-1] 的答案。这样我们在处理的过程之中,只需要考虑区间上界的问题,而不用考虑区间下界。

记搜的大概过程就是从数的首位向下搜索,到末位得到最终方案数,一层一层向上返回答案并累加,最后从起点得到最终答案。

那么对于每一层递归,应该如何正确的设计状态呢?

其余的这样那样的参数视情况而定。

参数解析

状态记录

别忘了记忆化搜索的特别之处:状态记录。

DP 是在记搜的框架下进行的,因此每当找到一种情况,我们就可以这种情况记录下来,等递归下去,遇到相同的情况时直接使用当前记录的值。这样可以大大减少不必要的递归层数。

记录数组的下标表示的是一种状态,只要当前的状态和之前搜过的某个状态完全一样,我们就可以直接返回原来已经记录下来的此状态的值。

需要注意的是,当当前有某些限制,例如存在 limit,lead 标记时,不可以记录或使用存储下来的值。

原因很简单,若存在限制标记,当前的状态是不完整的,所以是不可以搜到某些取值范围的,记录下来的答案自然也是存在偏差的。

例题

P2657 [SCOI2009] windy 数

经典模板题。

直接上记搜代码,内有注释。

int DFS(int pos,int pre,bool lead,bool limit){//pos:当前递归层数,pre:上一位数码,lead:前导0限制,limit:最高位限制
    if(pos>len)//递归尽头
        return 1;
    if(dp[pos][pre]!=-1&&(!lead)&&(!limit))//可以直接取用答案
        return dp[pos][pre];
    int tmp=0,digit=9;
    if(limit)//存在最高位限制
        digit=a[len-pos+1];//最多只能取到原数当前数位
    for(int i=0;i<=digit;i++){
        if(abs(pre-i)<2)//选取数码与上一位数码之差小于2,不满足条件
            continue;
        bool nxt=bool((i==digit)&limit);//下一位的最高位限制
        if((!i)&&lead)//有前导0
            tmp+=DFS(pos+1,15,1,nxt);//上一位数码取15,这样无论如何和下一位数码的差都不小于2
        else if(i&&lead)
            tmp+=DFS(pos+1,i,0,nxt);
        else
            tmp+=DFS(pos+1,i,0,nxt);
        //以上是没有前导0
    }
    if(!limit&&!lead)
        dp[pos][pre]=tmp;//记录答案
    return tmp;
}

P4127 [AHOI2009] 同类分布

有点难度的题。

参数设计简单。位数参数 pos,数位和 sum,最终数 num,最高位限制 limit

写着写着,我们发现了一个问题。num 理论最大值为 10^{18},所以记录数组不能记录带有 num 的状态。考虑取模,但又发现 sum 是不固定的,无法实时取模。

那怎么办?约定模数呗。

枚举模数,统计所有模数下的答案之和。答案判定合法即在递归尽头的 sum 等于模数,且不断取模的 num 在递归尽头的值为 0

贴个完整代码:

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=20,M=200;
int A,B,len,ans;
int a[N];
int dp[N][M][M];
int DFS(int pos,int sum,int num,bool limit,int mod){//当前位,数位和,当前取模结果,最高位限制,约定模数
    if(pos>len){//递归尽头
        if(num==0&&sum==mod)//取模结果为0且数位和等于模数
            return 1;
        return 0;
    }
    if(dp[pos][sum][num]!=-1&&(!limit))//取用答案
        return dp[pos][sum][num];
    int digit=9,tmp=0;
    if(limit)
        digit=a[len-pos+1];
    for(int i=0;i<=digit;i++){
        bool nxt=(limit&(i==digit));
        tmp+=DFS(pos+1,sum+i,(num*10+i)%mod,nxt,mod);
    }
    if(!limit)//记录
        dp[pos][sum][num]=tmp;
    return tmp;
}
int part(int x,int mod){//拆分数位
    len=0;
    while(x){
        a[++len]=x%10;
        x/=10;
    }
    memset(dp,-1,sizeof(dp));
    return DFS(1,0,0,1,mod);
}
signed main(){
    cin>>A>>B;
    for(int i=1;i<=162;i++)//枚举约定模数
        ans+=part(B,i)-part(A-1,i);
    cout<<ans;
    return 0;
}

CF1073E Segment Sum

数位 DP 和状态压缩。

考虑用状压维护每个数出现的状态。

dp_{pos,state} 表示在第 pos 位,状态为 state 所给贡献。

注意到题目要求维护的是满足要求的数之和,而非个数,所以不能只用常规的 dp_{pos,state}=\sum dp_{pos+1,state'} 进行转移,其中 state' 表示添加 x 之后的状态。而是要分成 2 部分进行考虑:第 pos 位本身所给贡献,以及 pos 之后的位总共的贡献。

后者自然是 \sum dp_{pos+1,state'}。前者如何求解?

pos 位是 x,那么其个体给予答案的贡献就为 x\times 10^{len-pos}\times cnt,其中 cnt 表示后面有多少种合法的答案。因为只靠 dp 数组无法维护 cnt,所以考虑多加一个数组 gg_{pos,state} 表示在 pos 位,状态为 state 后面有多少种合法答案。

整合一下,得到 dp_{pos,state}=\sum dp_{pos+1,state'}+x\times 10^{len-pos}\times g_{pos+1,state'}

当然,还有 g_{pos,state}=\sum g_{pos+1,state'}

思路就这样,只是代码细节超多。

代码:

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=25,M=2e3+5,Mod=998244353;
int A,B,K,len;
int a[N],power[N];
int dp[N][M],g[N][M];
struct node{
    int res,cnt;//dp[][],g[][]
};
void init(){//预处理10的次幂 
    power[0]=1;
    for(int i=1;i<=18;i++)
        power[i]=power[i-1]*10%Mod;
    return;
}
node DFS(int pos,int state,bool lead,bool limit){//由于每次向上一层返回状态时需要用dp,g两个状态,所以DFS的类型是node 
    if(__builtin_popcount(state)>K)//含不同数字超过K,剪枝 
        return {0,0};
    if(pos>len)
        return {0,1};
    if(!limit&&!lead&&dp[pos][state]!=-1)//常规调用答案 
        return {dp[pos][state],g[pos][state]};
    int digit=9;
    node tmp={0,0};
    if(limit)
        digit=a[len-pos+1];
    for(int i=0;i<=digit;i++){
        bool nxtlim=bool(limit&(i==digit)),nxtlea=bool(!i&&lead);
        int nxtsta=0;//当前是前导0,状态为0 
        if(!lead||i)//不是前导0,状态中表示i的那一位为1,表示出现过i 
            nxtsta=(state|(1<<i));
        node nxt=DFS(pos+1,nxtsta,nxtlea,nxtlim);
        tmp.res=tmp.res+nxt.res+nxt.cnt*power[len-pos]%Mod*i%Mod,tmp.res%=Mod;
        tmp.cnt=tmp.cnt+nxt.cnt,tmp.cnt%=Mod;
        //记录答案
        //这里注意一个点:直接处理出node的值nxt,计算答案时直接调用nxt,不然的话如果计算答案时反复调用,例如tmp.res=tmp.res+DFS(pos+1,nxtsta,nxtlea,nxtlim).res+...会T到起飞 
    }
    if(!limit&&!lead)//常规记录答案 
        dp[pos][state]=tmp.res,g[pos][state]=tmp.cnt;
    return tmp;
}
int part(int x){
    len=0;
    while(x){
        a[++len]=x%10;
        x/=10;
    }
    memset(dp,-1,sizeof(dp));
    memset(g,-1,sizeof(g));
    return DFS(1,0,1,1).res;
}
signed main(){
    init();
    cin>>A>>B>>K;
    cout<<((part(B)-part(A-1))%Mod+Mod)%Mod;//注意减法取模 
    return 0;
}

P4067 [SDOI2016] 储能表

发现这个 \max 有点难搞。考虑将题意转化为:找到 i\oplus j>k 的个数 cnt 和总和 sum,答案即为 sum-k\times cnt

发现数据范围很大。考虑数位 DP。不一样的是这次的拆位是二进制。考虑 3 重限制。两重是要看当前的 i,j 是否顶满,这个是常规操作。最后一重是看 i\oplus j 的值是不是仍然卡在 k 上。(如果当前 i\oplus j 的值大于 k,则最后一定会大于 k,这个限制就是 0;否则就是 1,因为后面仍然有可能 <k,这是不合法的。)最后是算贡献。个数直接加即可。和应该是加上上一位返回过来的和再加上上一位返回回来的个数乘上当前这一位的异或值乘上当前这一位的幂次权值,其中当前这一位的异或值乘上当前这一位的幂次权值是当前这一位的新增贡献,相当于我们拆位算贡献(幂次权值即 2 的几次方)。

代码可能和平常的数位 DP 模板稍有改动,原因是 i,j 位数不统一所以不能按常规方法去得到当前这一位的值。

#include<bits/stdc++.h>
#define int long long
#define pii pair<int,int>
#define fi first
#define se second
using namespace std;
const int N=75;
int T,n,m,k,p,a[N],b[N],pw[N];
pii dp[N][2][2][2];
pii DFS(int pos,bool lim1,bool lim2,bool lim3){
    if(pos<0)return {1,0};
    if(dp[pos][lim1][lim2][lim3].fi!=-1)return dp[pos][lim1][lim2][lim3];
    int dg1=1,dg2=1,dg3=(k>>pos&1);pii tmp={0,0};
    if(lim1&&!(n>>pos&1))dg1=0;
    if(lim2&&!(m>>pos&1))dg2=0;
    for(int i=0;i<=dg1;i++){
        for(int j=0;j<=dg2;j++){
            if(lim3&&(i^j)<dg3)continue;
            pii tt=DFS(pos-1,lim1&(i==dg1),lim2&(j==dg2),lim3&((i^j)==dg3));
            tmp.fi=(tmp.fi+tt.fi)%p,tmp.se=(tmp.se+tt.fi*(i^j)%p*pw[pos]%p+tt.se)%p;
        }
    }
    return dp[pos][lim1][lim2][lim3]=tmp;
}
int part(){
    memset(dp,-1,sizeof(dp));
    pii t=DFS(61,1,1,1);
    return (t.se-t.fi*(k%p)%p+p)%p;
}
signed main(){
    cin>>T;
    while(T--){
        cin>>n>>m>>k>>p,n--,m--;
        pw[0]=1;for(int i=1;i<=60;i++)pw[i]=pw[i-1]*2ll%p;
        cout<<part()<<"\n";
    }
    return 0;
}