【多项式】FWT 学习笔记
GIFBMP
·
·
个人记录
一、简介
FWT 是一种用来快速计算形如
C_k=\sum_{i\oplus j=k}A_i\times B_j
的卷积式的算法,其中 \oplus 可以为 \text{and,or,xor} 等位运算。
我们记 (A,B) 表示将 B 接到 A 后面构成的新序列,A+B 表示 A 和 B 按位相加得到的新序列,A\times B 表示 A 和 B 按位相乘得到的新序列。
根据和 FFT 相似的思路,我们可以先对 A,B 进行一次变换,记作 FWT(A),FWT(B),使得 FWT(C)=FWT(A)\times FWT(B)。
显然有:
(FWT(A),FWT(B))\times(FWT(C),FWT(D))=(FWT(A)\times FWT(C),FWT(B)\times FWT(D))
由于 FWT 是一种线性变换,故满足 FWT(A+B)=FWT(A)+FWT(B)。
则 FWT 满足分配律,即:
FWT(A)\times FWT(B+C)=FWT(A)\times FWT(B)+FWT(A)\times FWT(C)
二、or-FWT
首先来讨论 \oplus 为 \text{or} 运算的情况。
我们设数列 A 的长度为 2^n,前 2^{n-1} 个数构成序列 A_0,后 2^{n-1} 个数构成序列 A_1。
那么有:
FWT(A)=(FWT(A_0),FWT(A_0)+FWT(A_1))
证明:
我们利用数学归纳法,在 |A|=|B|=1 时,显然有 FWT(C)=FWT(A)\times FWT(B) 成立。
设 FWT(C_0)=FWT(A_0)\times FWT(B_0) 和 FWT(C_1)=FWT(A_1)\times FWT(B_1) 均成立,记 A|B 表示 A 和 B 的或卷积,那么有:
\begin{aligned}
FWT(C)&=FWT(A|B)\\
&=FWT((A|B)_0,(A|B)_1)\\
&=FWT(A_0|B_0,A_0|B_1+A_1|B_0+A_1|B_1)\\
&=(FWT(A_0|B_0),FWT(A_0|B_0+A_0|B_1+A_1|B_0+A_1|B_1))\\
&=(FWT(A_0)\times FWT(B_0),FWT(A_0)\times FWT(B_0)+FWT(A_0)\times \\
&\ \ \ \ \ FWT(B_1)+FWT(A_1)\times FWT(B_0)+FWT(A_1)\times FWT(B_1))\\
&=(FWT(A_0)\times FWT(B_0),(FWT(A_0)+FWT(A_1))\times (FWT(B_0)+FWT(B_1)))\\
&=(FWT(A_0),FWT(A_0)+FWT(A_1))\times (FWT(B_0),FWT(B_0)+FWT(B_1))\\
&=FWT(A)\times FWT(B)
\end{aligned}
然后对 C 进行一遍 UFWT 即可得到答案。
由于 UFWT 为 FWT 的逆变换,则有:
UFWT(A)=(UFWT(A_0),UFWT(A_1)-UFWT(A_0))
三、and-FWT
当 \oplus 为 \text{and} 时,可以得到:
FWT(A)=(FWT(A_0)+FWT(A_1),FWT(A_1))
证明同理,设 FWT(C_0)=FWT(A_0)\times FWT(B_0) 和 FWT(C_1)=FWT(A_1)\times FWT(B_1) 均成立,记 A\&B 表示 A 和 B 的与卷积,那么有:
\begin{aligned}
FWT(C)&=FWT(A\&B)\\
&=FWT((A\&B)_0,(A\&B)_1)\\
&=FWT(A_0\&B_0+A_0\&B_1+A_1\&B_0,A_1\&B_1)\\
&=(FWT(A_0\&B_0+A_0\&B_1+A_1\&B_0+A_1\&B_1),FWT(A_1\&B_1))\\
&=(FWT(A_0)\times FWT(B_0)+FWT(A_0)\times FWT(B_1)+FWT(A_1)\times \\
&\ \ \ \ \ FWT(B_0)+FWT(A_1)\times FWT(B_1),FWT(A_1)\times FWT(B_1))\\
&=(FWT(A_0)\times FWT(B_0),(FWT(A_0)+FWT(A_1))\times (FWT(B_0)+FWT(B_1)))\\
&=(FWT(A_0)+FWT(A_1),FWT(A_1),)\times (FWT(B_0)+FWT(B_1),FWT(B_1))\\
&=FWT(A)\times FWT(B)
\end{aligned}
同理可得:
UFWT(A)=(UFWT(A_0)-UFWT(A_1),UFWT(A_1))
四、xor-FWT
当 \oplus 为 \text{xor} 时,可以得到:
FWT(A)=(FWT(A_0)+FWT(A_1),FWT(A_0)-FWT(A_1))
证明同理,设 FWT(C_0)=FWT(A_0)\times FWT(B_0) 和 FWT(C_1)=FWT(A_1)\times FWT(B_1) 均成立,记 A\otimes B 表示 A 和 B 的异或卷积,那么有:
\begin{aligned}
FWT(C)&=FWT(A\otimes B)\\
&=(FWT((A\otimes B)_0+(A\otimes B)_1),FWT((A\otimes B)_0-(A\otimes B)_1))\\
&=(FWT(A_0
\otimes B_0+A_1\otimes B_1+A_1\otimes B_0+A_0\otimes B_1),\\
&\ \ \ \ \ FWT(A_0
\otimes B_0+A_1\otimes B_1-A_1\otimes B_0-A_0\otimes B_1))\\
&=((FWT(A_0)+FWT(A_1))\times (FWT(B_0)+FWT(B_1)),\\
&\ \ \ \ \ (FWT(A_0)-FWT(A_1))\times (FWT(B_0)-FWT(B_1)))\\
&=(FWT(A_0)+FWT(A_1),FWT(A_0)-FWT(A_1))\times \\
&\ \ \ \ \ (FWT(B_0)+FWT(B_1),FWT(B_0)-FWT(B_1))\\
&=FWT(A)\times FWT(B)
\end{aligned}
同理可得:
UFWT(A)=(\dfrac{UFWT(A_0)+UFWT(A_1)}{2},\dfrac{UFWT(A_0)-UFWT(A_1)}{2})
Code:
#include <cstdio>
using namespace std ;
const int MAXN = 2e5 + 10 , mod = 998244353 , inv2 = 499122177 ;
int n , a[MAXN] , b[MAXN] , c[MAXN] , ta[MAXN] , tb[MAXN] ;
void FWT1 (int *t , int typ) {
for (int i = 2 ; i <= (1 << n) ; i <<= 1)
for (int p = (i >> 1) , j = 0 ; j < (1 << n) ; j += i)
for (int k = j ; k < j + p ; k++)
if (~typ) t[p + k] = (t[p + k] + t[k]) % mod ;
else t[p + k] = (t[p + k] - t[k] + mod) % mod ;
}
void FWT2 (int *t , int typ) {
for (int i = 2 ; i <= (1 << n) ; i <<= 1)
for (int p = (i >> 1) , j = 0 ; j < (1 << n) ; j += i)
for (int k = j ; k < j + p ; k++)
if (~typ) t[k] = (t[k] + t[k + p]) % mod ;
else t[k] = (t[k] - t[k + p] + mod) % mod ;
}
void FWT3 (int *t , int typ) {
for (int i = 2 ; i <= (1 << n) ; i <<= 1)
for (int p = (i >> 1) , j = 0 ; j < (1 << n) ; j += i)
for (int k = j ; k < j + p ; k++) {
int x = t[k] , y = t[p + k] ;
if (~typ) t[k] = (x + y) % mod , t[p + k] = (x - y + mod) % mod ;
else t[k] = 1LL * (x + y) % mod * inv2 % mod , t[p + k] = 1LL * (x - y + mod) % mod * inv2 % mod ;
}
}
int main () {
scanf ("%d" , &n) ;
for (int i = 0 ; i < (1 << n) ; i++) scanf ("%d" , &a[i]) , ta[i] = a[i] ;
for (int i = 0 ; i < (1 << n) ; i++) scanf ("%d" , &b[i]) , tb[i] = b[i] ;
FWT1 (ta , 1) , FWT1 (tb , 1) ;
for (int i = 0 ; i < (1 << n) ; i++) c[i] = 1LL * ta[i] * tb[i] % mod ;
FWT1 (c , -1) ;
for (int i = 0 ; i < (1 << n) ; i++) printf ("%d " , c[i]) , ta[i] = a[i] , tb[i] = b[i] ;
puts ("") ;
FWT2 (ta , 1) , FWT2 (tb , 1) ;
for (int i = 0 ; i < (1 << n) ; i++) c[i] = 1LL * ta[i] * tb[i] % mod ;
FWT2 (c , -1) ;
for (int i = 0 ; i < (1 << n) ; i++) printf ("%d " , c[i]) , ta[i] = a[i] , tb[i] = b[i] ;
puts ("") ;
FWT3 (ta , 1) , FWT3 (tb , 1) ;
for (int i = 0 ; i < (1 << n) ; i++) c[i] = 1LL * ta[i] * tb[i] % mod ;
FWT3 (c , -1) ;
for (int i = 0 ; i < (1 << n) ; i++) printf ("%d " , c[i]) ;
return 0 ;
}