【多项式】FWT 学习笔记

· · 个人记录

一、简介

FWT 是一种用来快速计算形如

C_k=\sum_{i\oplus j=k}A_i\times B_j

的卷积式的算法,其中 \oplus 可以为 \text{and,or,xor} 等位运算。

我们记 (A,B) 表示将 B 接到 A 后面构成的新序列,A+B 表示 AB 按位相加得到的新序列,A\times B 表示 AB 按位相乘得到的新序列。

根据和 FFT 相似的思路,我们可以先对 AB 进行一次变换,记作 FWT(A)FWT(B),使得 FWT(C)=FWT(A)\times FWT(B)

显然有:

(FWT(A),FWT(B))\times(FWT(C),FWT(D))=(FWT(A)\times FWT(C),FWT(B)\times FWT(D))

由于 FWT 是一种线性变换,故满足 FWT(A+B)=FWT(A)+FWT(B)

则 FWT 满足分配律,即:

FWT(A)\times FWT(B+C)=FWT(A)\times FWT(B)+FWT(A)\times FWT(C)

二、or-FWT

首先来讨论 \oplus\text{or} 运算的情况。

我们设数列 A 的长度为 2^n,前 2^{n-1} 个数构成序列 A_0,后 2^{n-1} 个数构成序列 A_1

那么有:

FWT(A)=(FWT(A_0),FWT(A_0)+FWT(A_1))

证明:

我们利用数学归纳法,在 |A|=|B|=1 时,显然有 FWT(C)=FWT(A)\times FWT(B) 成立。

FWT(C_0)=FWT(A_0)\times FWT(B_0)FWT(C_1)=FWT(A_1)\times FWT(B_1) 均成立,记 A|B 表示 AB 的或卷积,那么有:

\begin{aligned} FWT(C)&=FWT(A|B)\\ &=FWT((A|B)_0,(A|B)_1)\\ &=FWT(A_0|B_0,A_0|B_1+A_1|B_0+A_1|B_1)\\ &=(FWT(A_0|B_0),FWT(A_0|B_0+A_0|B_1+A_1|B_0+A_1|B_1))\\ &=(FWT(A_0)\times FWT(B_0),FWT(A_0)\times FWT(B_0)+FWT(A_0)\times \\ &\ \ \ \ \ FWT(B_1)+FWT(A_1)\times FWT(B_0)+FWT(A_1)\times FWT(B_1))\\ &=(FWT(A_0)\times FWT(B_0),(FWT(A_0)+FWT(A_1))\times (FWT(B_0)+FWT(B_1)))\\ &=(FWT(A_0),FWT(A_0)+FWT(A_1))\times (FWT(B_0),FWT(B_0)+FWT(B_1))\\ &=FWT(A)\times FWT(B) \end{aligned}

然后对 C 进行一遍 UFWT 即可得到答案。

由于 UFWT 为 FWT 的逆变换,则有:

UFWT(A)=(UFWT(A_0),UFWT(A_1)-UFWT(A_0))

三、and-FWT

\oplus\text{and} 时,可以得到:

FWT(A)=(FWT(A_0)+FWT(A_1),FWT(A_1))

证明同理,设 FWT(C_0)=FWT(A_0)\times FWT(B_0)FWT(C_1)=FWT(A_1)\times FWT(B_1) 均成立,记 A\&B 表示 AB 的与卷积,那么有:

\begin{aligned} FWT(C)&=FWT(A\&B)\\ &=FWT((A\&B)_0,(A\&B)_1)\\ &=FWT(A_0\&B_0+A_0\&B_1+A_1\&B_0,A_1\&B_1)\\ &=(FWT(A_0\&B_0+A_0\&B_1+A_1\&B_0+A_1\&B_1),FWT(A_1\&B_1))\\ &=(FWT(A_0)\times FWT(B_0)+FWT(A_0)\times FWT(B_1)+FWT(A_1)\times \\ &\ \ \ \ \ FWT(B_0)+FWT(A_1)\times FWT(B_1),FWT(A_1)\times FWT(B_1))\\ &=(FWT(A_0)\times FWT(B_0),(FWT(A_0)+FWT(A_1))\times (FWT(B_0)+FWT(B_1)))\\ &=(FWT(A_0)+FWT(A_1),FWT(A_1),)\times (FWT(B_0)+FWT(B_1),FWT(B_1))\\ &=FWT(A)\times FWT(B) \end{aligned}

同理可得:

UFWT(A)=(UFWT(A_0)-UFWT(A_1),UFWT(A_1))

四、xor-FWT

\oplus\text{xor} 时,可以得到:

FWT(A)=(FWT(A_0)+FWT(A_1),FWT(A_0)-FWT(A_1))

证明同理,设 FWT(C_0)=FWT(A_0)\times FWT(B_0)FWT(C_1)=FWT(A_1)\times FWT(B_1) 均成立,记 A\otimes B 表示 AB 的异或卷积,那么有:

\begin{aligned} FWT(C)&=FWT(A\otimes B)\\ &=(FWT((A\otimes B)_0+(A\otimes B)_1),FWT((A\otimes B)_0-(A\otimes B)_1))\\ &=(FWT(A_0 \otimes B_0+A_1\otimes B_1+A_1\otimes B_0+A_0\otimes B_1),\\ &\ \ \ \ \ FWT(A_0 \otimes B_0+A_1\otimes B_1-A_1\otimes B_0-A_0\otimes B_1))\\ &=((FWT(A_0)+FWT(A_1))\times (FWT(B_0)+FWT(B_1)),\\ &\ \ \ \ \ (FWT(A_0)-FWT(A_1))\times (FWT(B_0)-FWT(B_1)))\\ &=(FWT(A_0)+FWT(A_1),FWT(A_0)-FWT(A_1))\times \\ &\ \ \ \ \ (FWT(B_0)+FWT(B_1),FWT(B_0)-FWT(B_1))\\ &=FWT(A)\times FWT(B) \end{aligned}

同理可得:

UFWT(A)=(\dfrac{UFWT(A_0)+UFWT(A_1)}{2},\dfrac{UFWT(A_0)-UFWT(A_1)}{2})

Code:

#include <cstdio>
using namespace std ;
const int MAXN = 2e5 + 10 , mod = 998244353 , inv2 = 499122177 ;
int n , a[MAXN] , b[MAXN] , c[MAXN] , ta[MAXN] , tb[MAXN] ;
void FWT1 (int *t , int typ) {
    for (int i = 2 ; i <= (1 << n) ; i <<= 1)
        for (int p = (i >> 1) , j = 0 ; j < (1 << n) ; j += i)
            for (int k = j ; k < j + p ; k++)
                if (~typ) t[p + k] = (t[p + k] + t[k]) % mod ;
                else t[p + k] = (t[p + k] - t[k] + mod) % mod ;
}
void FWT2 (int *t , int typ) {
    for (int i = 2 ; i <= (1 << n) ; i <<= 1)
        for (int p = (i >> 1) , j = 0 ; j < (1 << n) ; j += i)
            for (int k = j ; k < j + p ; k++)
                if (~typ) t[k] = (t[k] + t[k + p]) % mod ;
                else t[k] = (t[k] - t[k + p] + mod) % mod ;
}
void FWT3 (int *t , int typ) {
    for (int i = 2 ; i <= (1 << n) ; i <<= 1)
        for (int p = (i >> 1) , j = 0 ; j < (1 << n) ; j += i)
            for (int k = j ; k < j + p ; k++) {
                int x = t[k] , y = t[p + k] ;
                if (~typ) t[k] = (x + y) % mod , t[p + k] = (x - y + mod) % mod ;
                else t[k] = 1LL * (x + y) % mod * inv2 % mod , t[p + k] = 1LL * (x - y + mod) % mod * inv2 % mod ;
            }
}
int main () {
    scanf ("%d" , &n) ;
    for (int i = 0 ; i < (1 << n) ; i++) scanf ("%d" , &a[i]) , ta[i] = a[i] ;
    for (int i = 0 ; i < (1 << n) ; i++) scanf ("%d" , &b[i]) , tb[i] = b[i] ;
    FWT1 (ta , 1) , FWT1 (tb , 1) ;
    for (int i = 0 ; i < (1 << n) ; i++) c[i] = 1LL * ta[i] * tb[i] % mod ;
    FWT1 (c , -1) ;
    for (int i = 0 ; i < (1 << n) ; i++) printf ("%d " , c[i]) , ta[i] = a[i] , tb[i] = b[i] ;
    puts ("") ;
    FWT2 (ta , 1) , FWT2 (tb , 1) ;
    for (int i = 0 ; i < (1 << n) ; i++) c[i] = 1LL * ta[i] * tb[i] % mod ;
    FWT2 (c , -1) ;
    for (int i = 0 ; i < (1 << n) ; i++) printf ("%d " , c[i]) , ta[i] = a[i] , tb[i] = b[i] ;
    puts ("") ;
    FWT3 (ta , 1) , FWT3 (tb , 1) ;
    for (int i = 0 ; i < (1 << n) ; i++) c[i] = 1LL * ta[i] * tb[i] % mod ;
    FWT3 (c , -1) ;
    for (int i = 0 ; i < (1 << n) ; i++) printf ("%d " , c[i]) ;
    return 0 ;
}