复习:矩阵快速幂

· · 个人记录

前言

emmm太久了忘了许多 写笔记来复习一下

概念

矩阵乘法

什么是矩阵乘法

给你两个矩阵a,b

则令c=a*b

c_n=a_n$,$c_m=b_m \sum\limits_{i=1}^{c_n}\sum\limits_{j=1}^{c_m} c_{i,j}\sum\limits_{k=1}^{a_m}a_{i,k}*b_{k,j}

两个矩阵做乘法的前提:a_m=b_n

找不到图片了 网上自取

抽象于行乘列即可

Code

node init(int n,int m)
{
    node c;
    c.n=n;
    c.m=m;
    for(int i=1;i<=c.n;i++)
        for(int j=1;j<=c.m;j++)
            c.r[i][j]=0;
    return c;
}
node operator *(node a,node b)
{
    node c=init(a.n,b.m);
    for(int i=1;i<=c.n;i++)
        for(int j=1;j<=c.m;j++)
            for(int k=1;k<=a.m;k++)
                c.r[i][j]=(c.r[i][j]+a.r[i][k]*b.r[k][j])%mod;
    return c;
}

初始矩阵

node clear(int n)
{
    node c=init(n,n);
    for(int i=1;i<=n;i++)
        c.r[i][i]=1;
    return c;
}

初始矩阵满足 P 满足任意矩阵 a 使得P*a=a

注意初始P_n=P_m=a_m

性质

定义矩阵a,b,c

注意 矩阵不存在交换律

矩阵快速幂

因为矩阵满足结合律 因此可以直接快速幂优化时间

node Qpow(node a,ll x)
{
    node sum=clear(a.m);
    while(x>0)
    {
        if(x&1) sum=sum*a;
        a=a*a;
        x/=2;
    }
    return sum;
}

矩阵优化递推 P1939

令初始A数组为(1,1,1) 把其看成a_{i-1},a_{i-2},a_{i-3} 思考如何推成a_{i},a_{i-1},a_{i-2}

发现构造矩阵:

b=\begin{bmatrix} 1&1&0\\0&0&1\\1&0&0 \end{bmatrix}

a*b即可转移一位

答案就是a*b^n

矩阵优化快速幂即可

int main()
{
    a.n=1,a.m=3;
    a.r[1][1]=1,a.r[1][2]=1,a.r[1][3]=1;
    b.n=b.m=3;
    b.r[1][1]=1,b.r[1][2]=1,b.r[1][3]=0;
    b.r[2][1]=0,b.r[2][2]=0,b.r[2][3]=1;
    b.r[3][1]=1,b.r[3][2]=0,b.r[3][3]=0;
    scanf("%d",&g);
    while(g--)
    {
        scanf("%d",&n);
        printf("%lld\n",(a*Qpow(b,n-3)).r[1][1]);
    }
    return 0;
}

矩阵优化 DP P4838

定义f_{i,0}表示后缀为 \texttt{a} 的方案数

定义f_{i,1}表示后缀为 \texttt{aa} 的方案数

定义f_{i,2}表示后缀为 \texttt{b} 的方案数

容易得到简单 dp

    for(int i=3;i<=n;i++)
    f[i][0]=f[i-1][2],
    f[i][1]=f[i-1][0],
    f[i][2]=f[i-1][0]+f[i-1][1]+f[i-1][2];

容易构造简单矩阵

\begin{bmatrix}0&1&1\\0&0&1\\1&0&1 \end{bmatrix}

套上矩阵快速幂即可快速求解

Code

int main()
{
    a.n=1,a.m=3;
    a.r[1][1]=1,a.r[1][2]=1,a.r[1][3]=2;
    b.n=b.m=3;
    b.r[1][1]=0,b.r[1][2]=1,b.r[1][3]=1;
    b.r[2][1]=0,b.r[2][2]=0,b.r[2][3]=1;
    b.r[3][1]=1,b.r[3][2]=0,b.r[3][3]=1;

    scanf("%d",&g);
    while(g--)
    {
        scanf("%d",&n);
        if(n==1) printf("2\n");
        else 
        {
            node ans=a*Qpow(b,n-2);
            printf("%lld\n",(ans.r[1][1]+ans.r[1][2]+ans.r[1][3])%mod);
        }
    }
    return 0;
}