转置原理小记

· · 个人记录

参考资料 : 候选队论文 2020,陈宇 :《转置原理的简单介绍》

转置原理 : 定义 & 概述

给定过程矩阵 A 以及输入向量 a ,求解输出向量 b=Aa 的算法,被称为线性算法。(又称线性变换)

写成更显式的求和,即 :

b_k=\sum\limits_{i=0}A_{k,i}a_i

这里我们认为输入向量是变量,矩阵 A 是常量。可以将 A_{k,i} 理解为“该算法中 a_ib_k 的贡献系数”。

为了方便处理,约定矩阵 A 为方阵,否则也只需适当补零,不影响理论推导。

形如 b=Aa 的算法,与形如 b'=A^Ta' 的算法互为转置。

转置原理给出了互为转置的两个线性算法之间的转化方法。

对于矩阵 A ,将其表示为若干初等矩阵的乘积,得 :

A=E_1E_2...E_m

在对某个向量执行线性算法时,只需将表示中的初等矩阵逐个乘上去。

根据上文中的转置定理,则有 :

A^T=E_m^T...E_2^TE_1^T

这说明,根据 A 的初等矩阵表示法,容易得到 A^T 的初等矩阵表示法。

倒序逐乘初等矩阵的转置即可以同样的时间复杂度完成转置算法。

当然,在实际操作中,为了方便,我们不会真的以初等矩阵为单位来拆分并转置,而是以整合后的矩阵为单位描述算法。

在下文中,我们将介绍一系列经典线性算法及其转置。

线性算法的例子

观察 \rm DFT 的矩阵 :

\begin{bmatrix} 1&1&1&...&1\\ 1&\omega_n^1&\omega_n^2&\dots&\omega_n^{n-1}\\ 1&\omega_n^2&\omega_n^4&\dots&\omega_n^{2(n-1)}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ 1&\omega_n^{n-1}&\omega_n^{2(n-1)}&\dots&\omega_n^{(n-1)^2}\\ \end{bmatrix}

该矩阵关于主对角线对称,故 \rm DFT 算法的转置恰是其本身。

对于 \rm IDFT 算法也有类似结论。

若将 H=F*G 中的 F 看成输入,G 看成常量,则有 :

H_k=\sum\limits_{i=0}G_{k-i}F_i

可知该线性算法 H=AF 的矩阵为 A_{i,j}=[i\geq j]G_{i-j}

将其转置,记 H'=A^TF ,则有 :

H_k'=\sum\limits_{i=k}G_{i-k}F_i=\sum\limits_{i=0}G_iF_{i+k}

即为减法卷积。

下面用 \times^T*^T 表示转置后的多项式乘法。(将符号右边进行转置)

考虑多项式 F 和序列 X ,求 F 在每个 X_i 处的取值。记答案为序列 Y

可写出 :

Y_k=F(X_k)=\sum\limits_{i=0}F_iX_k^i

可知该线性算法 Y=AF 的矩阵为 A_{i,j}=X_k^i

将其转置,记 Y'=A^TF ,则有 :

Y_k'=\sum\limits_{i=0}F_iX_i^k

可以看做求一个多项式 G

\begin{aligned} G(x)&=\sum\limits_{k=0}x^k\sum\limits_{i=0}F_iX_i^k\\ &=\sum\limits_{i=0}F_i\sum\limits_{k=0}x^kX_i^k\\ &=\sum\limits_{i=0}\dfrac{F_i}{1-xX_i} \end{aligned}

接下来,我们要找出一个求解上式的高效线性算法。一个经典的算法是 : 分治 FFT,复杂度为 O(n\log^2n)

对于分治区间 [l,r) ,维护分母 Q_{[l,r)}(x)=\prod_{i=l}^r(1-xX_i) 与分子 P_{[l,r)}(x)=\sum_{k=l}^rF_k\prod_{i=l,\ i\neq k}^r(1-xX_i)

在利用 [l,m),[m,r) 的信息计算 [l,r) 时,有 :

\begin{aligned} Q_{[l,r)}&=Q_{[l,m)}Q_{[m,r)}\\ P_{[l,r)}&=P_{[l,m)}Q_{[m,r)}+P_{[m,r)}Q_{[l,m)} \end{aligned}

最终,计算 P_{[0,n)}/Q_{[0,n)} 即可得到答案。

这个线性算法分为两步。

  1. 将分治树的每一层看做一个线性算法,从低到高依次进行(利用多项式乘法)。

  2. 最终,将 \frac{1}{Q_{[0,n)}} 看做由转移矩阵得到的常量,进行多项式乘法。

所以,转置后的算法可以这样实现。

  1. 计算 P'_{[0,n)}=F\times^{T}\frac{1}{Q_{[0,n)}}

  2. 分析原分治的转置,将每一层的转移看做一个线性算法。

    P_{[0,m)},P_{[m,n)} 看做输入,最终得到的 P_{[0,n)} 看做输出。

    简记 Q^L=Q_{[0,m)},Q^R=Q_{[m,n)}

    \begin{aligned} P_{[0,n)}&=P_{[0,m)}Q_R+P_{[m,n)}Q_L\\ \Rightarrow P_k&=\sum\limits_{i=0}^{\min(m-1,k)}P^*_iQ^R_{k-i}+\sum\limits_{i=0}^{\min(n-m-1,k)}P^*_{i+m}Q^L_{k-i} \end{aligned}

    可以观察到该算法的转移矩阵为 A_{k,i}=[0\leq i<m]Q^R_{k-i}+[m\leq i<n]Q^L_{k-i-m}

    故转置后的矩阵为 A^T_{k,i}=[0\leq k<m]Q^R_{i-k}+[m\leq k<n]Q^L_{i-k-m}

    即转置算法的结果为 P' ,写作和式,可得 :

    \begin{aligned} P_k'&=\sum\limits_{i=k}^{m-1}P^*_iQ^R_{i-k}&(0\leq k<m)\\ P_k'&=\sum\limits_{i=k+m}^{n-1}P^*_iQ^L_{i-k-m}&(m\leq k<n) \end{aligned}

    不难发现,两者实际上是转置乘法。即

    \begin{aligned} P'_{[0,m)}=P_{[0,n)}\times ^TQ^R\\ P'_{[m,n)}=P_{[0,n)}\times ^TQ^L \end{aligned}

    有没有更简洁的方法来得到这一结论呢?

    $$P_{[0,m)}\xrightarrow{\times Q^R}P_{[0,n)}\xleftarrow{\times Q^L}P_{[m,n)}$$ 故将贡献流向(以及边内部)取反后,有贡献图 : $$P_{[0,m)}\xleftarrow{\times^TQ^R}P_{[0,n)}\xrightarrow{\times^TQ^L}P_{[m,n)}$$ 此时转移就一目了然了。

最终,分治树叶节点的 P 的常数项即为答案。正如在转置算法中他们作为初始值。

至此,我们得到了一个常数较小,且易于实现(不需要多项式取模)的多点求值算法。

#include<algorithm>
#include<cstring>
#include<cstdio>
#include<vector>
#define ll long long
#define ull unsigned ll
#define clr(f,n) memset(f,0,sizeof(int)*(n))
#define cpy(f,g,n) memcpy(f,g,sizeof(int)*(n))
#define pb push_back
const int _G=3,mod=998244353,Maxn=1<<16|500;
using namespace std; 
ll powM(ll a,ll t=mod-2){
  ll ans=1;
  while(t){
    if(t&1)ans=ans*a%mod;
    a=a*a%mod;t>>=1;
  }return ans;
}
const int invG=powM(_G);
int tr[Maxn<<1],tf;
void tpre(int n){
  if (tf==n)return ;tf=n;
  for(int i=0;i<n;i++)
    tr[i]=(tr[i>>1]>>1)|((i&1)?n>>1:0);
}
void NTT(int *g,bool op,int n)
{
  tpre(n);
  static ull f[Maxn<<1],w[Maxn];w[0]=1;
  for (int i=0;i<n;i++)f[i]=(((ll)mod<<4)+g[tr[i]])%mod;
  for(int l=1;l<n;l<<=1){
    ull tG=powM(op?_G:invG,(mod-1)/(l+l));
    for (int i=1;i<l;i++)w[i]=w[i-1]*tG%mod;
    for(int k=0;k<n;k+=l+l)
      for(int p=0;p<l;p++){
        int tt=w[p]*f[k|l|p]%mod;
        f[k|l|p]=f[k|p]+mod-tt;
        f[k|p]+=tt;
      }
  }if (!op){
    ull invn=powM(n);
    for(int i=0;i<n;++i)
      g[i]=f[i]%mod*invn%mod;
  }else for(int i=0;i<n;++i)g[i]=f[i]%mod;
}
void px(int *f,int *g,int n)
{for(int i=0;i<n;++i)f[i]=1ll*f[i]*g[i]%mod;}
#define Poly vector<int>
Poly operator + (const Poly &A,const Poly &B)
{
  Poly C=A;C.resize(max(A.size(),B.size()));
  for (int i=0;i<B.size();i++)C[i]=(C[i]+B[i])%mod;
  return C;
}
Poly operator - (const Poly &A,const Poly &B)
{
  Poly C=A;C.resize(max(A.size(),B.size()));
  for (int i=0;i<B.size();i++)C[i]=(C[i]+mod-B[i])%mod;
  return C;
}
Poly operator * (int c,const Poly &A)
{
  Poly C;C.resize(A.size());
  for (int i=0;i<A.size();i++)C[i]=1ll*c*A[i]%mod;
  return C;
}
Poly operator * (const Poly &A,const Poly &B)
{
  static int a[Maxn<<1],b[Maxn<<1];
  if (min(A.size(),B.size())<=40){
    Poly C;C.resize(A.size()+B.size()-1);
    for (int i=0;i<A.size();i++)
      for (int j=0;j<B.size();j++)
        C[i+j]=(C[i+j]+1ll*A[i]*B[j])%mod;
    return C;
  }
  cpy(a,&A[0],A.size());
  cpy(b,&B[0],B.size());
  Poly C;C.resize(A.size()+B.size()-1);
  int n=1;for(n;n<C.size();n<<=1);
  NTT(a,1,n);NTT(b,1,n);
  px(a,b,n);NTT(a,0,n);
  cpy(&C[0],a,C.size());
  clr(a,n);clr(b,n);
  return C;
}
Poly mulT(Poly &A,const Poly &B)
{
  reverse(A.begin(),A.end());
  Poly C=A*B;C.resize(A.size());
  reverse(C.begin(),C.end());reverse(A.begin(),A.end());
  return C;
}  
void inv(const Poly &A,Poly &B,int n)
{
  if (n==1)B.push_back(powM(A[0]));
  else if (n&1){
    inv(A,B,--n);
    int sav=0;
    for (int i=0;i<n;i++)sav=(sav+1ll*B[i]*A[n-i])%mod;
    B.push_back(1ll*sav*powM(mod-A[0])%mod);
  }else {
    inv(A,B,n/2);
    Poly sA;sA.resize(n);
    cpy(&sA[0],&A[0],n);
    B=2*B-B*B*sA;
    B.resize(n);
  }
}
Poly inv(const Poly &A)
{Poly C;inv(A,C,A.size());return C;}
Poly Q[Maxn<<1],P[22];
int sx[Maxn];
void solve1(int l,int r,int u)
{
  if (l==r){
    Q[u].pb(1);
    Q[u].pb(mod-sx[l]);
    return ;
  }int mid=(l+r)>>1;
  solve1(l,mid,u<<1);
  solve1(mid+1,r,u<<1|1);
  Q[u]=Q[u<<1]*Q[u<<1|1];
}
int sy[Maxn];
void solve2(int l,int r,int u,int dep)
{
  if (l==r){sy[l]=P[dep][0];return ;}
  int mid=(l+r)>>1;
  P[dep].resize(r-l+1);
  P[dep+1]=mulT(P[dep],Q[u<<1|1]);
  solve2(l,mid,u<<1,dep+1);
  P[dep+1]=mulT(P[dep],Q[u<<1]);
  solve2(mid+1,r,u<<1|1,dep+1);
}
int n,m;Poly F;
int main()
{
  scanf("%d%d",&n,&m);
    F.resize(++n);
    for (int i=0;i<n;i++)scanf("%d",&F[i]);
    for (int i=0;i<m;i++)scanf("%d",&sx[i]);
    solve1(0,m-1,1);Q[1].resize(max(n,m));
    P[0]=mulT(F,inv(Q[1]));
    solve2(0,m-1,1,0);
    for (int i=0;i<m;i++)
      printf("%d\n",sy[i]);
    return 0;
}

例题

F_i(x)=(1-p_i)+p_ix,\ F(x)=\prod_{i=1}^nF_i ,则 F/F_ia 点积后的系数和即为 f_1

a 看做输入向量, f 看做输出向量,求解的过程是一个线性变换 f=Aa ,其中 A_{k,i}=[x^i](F/F_k)

考虑转置问题 g=A^Ta ,则有 g_k=\sum\limits_{i}[x^k](F/F_i) ,于是问题变为了求 \sum\limits_{i=1}^nF/F_i

这显然也可以利用分治 FFT 求解 :

对于分治区间 [l,r) ,记 Q_{[l,r)}(x)=\prod\limits_{i\in[l,r)}F_i,\ P_{[l,r)}(x)=\sum\limits_{i\in[l,r)}Q_{[l,r)}/F_i

转移和上文中的多点求值类似。故转置后的形式也类似。

评测记录