P3216 [HNOI2011]数学作业

· · 题解

『矩阵乘法』的好题!

题意看起来十分简单,就是要求n首尾拼起来 \mod m \ 的值。

可是n的值那么大,感觉无从下手。。。

我们仔细来分析一下暴力计算的时候,它的流程到底是什么样的? 对于当前的值val,加入一个v后,val的值变成:val*10^k+vv的位数)。

我们应该知道一件事:矩阵乘法经常用来优化线性递推式。

对于上面的式子,其实就是:a[i]=a[i-1]*10^k+a[i].如果k不变的话,那好办,我定义一个矩阵为(dp[i-1],i,1),通过与另一个矩阵相乘得到到(dp[i],i+1,1)的矩阵,最后答案取第一项即可。

但问题来了,k并不是一个定值怎么办?

也好办,发现一个重要信息:k在某一段内是不会变的,于是我们有一个笨拙的方法:按k分为若干段,将每一段内的答案用矩阵单独处理,再合并。

然后我们发现,实际上最多也就分成18段,所以我们知道:它就是正解!

于是对于每一段,我们可以推出另一个矩阵就是:

\ 10^k \ 0 \ 0 \ \ 1 \ \ \ 1 \ 0 \ \ 0 \ \ \ 1 \ 1

只有k是需要改变的。

总结:矩阵乘法中分段处理的思想是十分常见的(比如NOI2020Day1T1),希望通过这到题,自己能够有所体会。

code:
#include<cmath>
#include<cstring>
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<queue>
#include<vector>
#include<map>
#include<bitset>
#define int long long
#define reg register
using namespace std;
const int maxn=4e5+5;
//const int mod=1e9+7;
const int INF=0x3f3f3f3f;
inline int read()
{
    int x=0,f=1;
    char ch=getchar();
    while(ch<'0' || ch>'9')
    {
        if(ch=='-')
        f=-1;
        ch=getchar();
    }
    while(ch>='0' && ch<='9')
    {
        x=x*10+ch-'0';
        ch=getchar();
    }
    return x*f;
}
struct Matrix
{
    int a[5][5];
}ans,wns;
int n,mod;
int v[25];
Matrix Work(Matrix x,Matrix y)
{
    Matrix c;
    memset(c.a,0,sizeof(c.a));
    for(reg int i=1;i<=3;i++)
    {
        for(reg int j=1;j<=3;j++)
        {
            for(reg int k=1;k<=3;k++)
            {   
                c.a[i][j]=(c.a[i][j]+x.a[i][k]%mod*y.a[k][j]%mod)%mod;
            }
        }
    }
    return c;
}   
Matrix mul(Matrix x,int y)
{
    Matrix res;
    memset(res.a,0,sizeof(res.a));
    for(reg int i=1;i<=3;i++)
        res.a[i][i]=1;
    while(y)
    {
        if(y & 1)
        res=Work(x,res);
        x=Work(x,x);
        y>>=1;
    }
    return res;
}
signed main()
{
    n=read(),mod=read();
    v[0]=1;
    for(reg int i=1;i<=18;i++)
    v[i]=v[i-1]*10;
    ans.a[1][1]=0,ans.a[1][2]=1,ans.a[1][3]=1;
    wns.a[2][1]=wns.a[2][2]=wns.a[3][2]=wns.a[3][3]=1;
    for(int i=1;;i++)
    {
        wns.a[1][1]=v[i]%mod;
        int tot=min(n,v[i]-1)-v[i-1]+1;
        Matrix val=mul(wns,tot);
        ans=Work(ans,val);
        if(v[i]-1>=n)
        break;
    }
    printf("%lld\n",ans.a[1][1]%mod);
    return 0;
}
/*
1 0 10
1 1
1 1
1 1
1 1
1 1
1 1
1 1
1 1
1 1
1 1
*/