浅谈循环矩阵的乘法和快速幂

· · 个人记录

前言

前置技能:矩阵乘法,矩阵快速幂

当然你不会的话也不会点进来(滑稽)

今天上午的HNOI模拟赛中,T1是这么一道题目:

有一个长度为n的环,执行s次操作,在一次操作中,

对于每一个数,它变为它左边的数乘上l以及它本身以及它右边的数乘上r的和。

求最后每一个位置上的数是多少。(计算时左边和右边的数都是上一次的数)

最后结果模上10^xl,r,x都为给定的常数

n\leq1000,s\leq2^{30}

很容易想到用矩阵快速幂来维护,假设我们现在有4个数字需要变换,设f[i][j]表示当前已经变换了i次,第j位上的数字是多少,有递推式:

\begin{aligned}&\begin{bmatrix}f[i-1][1]&f[i-1][2]&f[i-1][3]&f[i-1][4]\end{bmatrix}\times\begin{bmatrix}1&l&0&r\\r&1&l&0\\0&r&1&l\\l&0&r&1\end{bmatrix}\\=&\begin{bmatrix}f[i][1]&f[i][2]&f[i][3]&f[i][4]\end{bmatrix}\end{aligned}

但是这样子做是O(n^3log_2s),妥妥的超时(在没有O(wys)优化的情况下)

这时我们发现,转移矩阵是一个循环矩阵。

循环矩阵是什么?

下面内容引自百度百科,侵删

​ 在线性代数中,循环矩阵是一种特殊形式的Toeplitz矩阵,它的行向量的每个元素都是前一个行向量各元素依次右移一个位置得到的结果。

简单地说,就是以一行不断向前/后变换的形式出现。循环矩阵有一些性质

假设矩阵a,b都是循环矩阵,那么:

循环矩阵的乘法和矩阵快速幂

根据性质2,那么这个转移矩阵无论自乘多少次,它还是一个循环矩阵,所以说,我们理论上只要知道第一行的最终形态,就可以知道整个矩阵了。那如何求出它的下一个形态呢?(转移矩阵为T

假设g[i]为当前矩阵第一行的第i个,g'[i]为下一个形态的矩阵的第i个。则有:

\begin{aligned}g'[1]=&T[1][1]\times T[1][1]+T[1][2]\times T[2][1]+T[1][3]\times T[3][1]+T[1][4]\times T[4][1]\\=&g[1]\times g[1]+g[2]\times g[4]+g[3]\times g[3]+g[4]\times g[2]\end{aligned}

这里可以得出这个结果是根据了循环矩阵的定义。

那么其他的求法也是雷同的,比如g'[2],其他的交给读者自己完成:

\begin{aligned}g'[2]=&T[1][1]\times T[1][2]+T[1][2]\times T[2][2]+T[1][3]\times T[3][2]+T[1][4]\times T[4][2]\\=&g[1]\times g[2]+g[2]\times g[1]+g[3]\times g[4]+g[4]\times g[3]\end{aligned}

于是我们得出了这样的规律:

g'[x]=\sum\limits_{(i+j-2)\%n+1=x}g[i]*g[j]

很显然,我们的初始矩阵S也是一个只有一行的循环矩阵,同样适用于上面的规律,只不过是两个不同的矩阵。

tmp[x]=\sum\limits_{(i+j-2)\%n+1=x}S[i]*T[j]

其中tmp即本次矩阵乘法的结果矩阵

那么之前所讲的例题再用矩阵快速幂就变成O(n^2log_2s)的了:

#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
using std::min; using std::max;
using std::swap; using std::sort;
typedef long long ll;

template<typename T>
void read(T &x) {
    int flag = 1; x = 0; char ch = getchar();
    while(ch < '0' || ch > '9') { if(ch == '-') flag = -flag; ch = getchar(); }
    while(ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar(); x *= flag;
}

const int N = 1e3 + 10;
int n, s, l, r, x, p = 1;
int S[N], T[N], tmp[N];

void mul(int S[], int T[]) {
    memset(tmp, 0, sizeof tmp);
    for(int i = 1; i <= n; ++i)
        for(int j = 1; j <= n; ++j)
            (tmp[(i + j - 2) % n + 1] += 1ll * S[i] * T[j] % p) %= p;
    memcpy(S, tmp, sizeof(tmp));
}//矩阵乘法

int main () {
    read(n), read(s), read(l), read(r), read(x);
    for(int i = 1; i <= x; ++i) p *= 10;
    for(int i = 1; i <= n; ++i) read(S[i]);
    T[1] = 1, T[2] = l, T[n] = r;
    for(; s; s >>= 1, mul(T, T)) if(s & 1) mul(S, T);//快速幂
    for(int i = 1; i <= n; ++i) printf("%d ", S[i]);
    return 0;
}