我也要学这种东西吗?(多项式)
wangxx2012
·
2026-04-15 18:05:14
·
算法·理论
2026/4/28:添加了多项式多点求值讲解。
2026/5/10:添加了多项式反三角函数讲解。
2026/5/13:反三角函数求导公式出锅,改,并做讲解。
2026/5/14:添加了多项式 k 次方根、多项式三角函数讲解。
2026/5/15:添加了多项式求导、多项式积分讲解。
2026/5/16:优化了排版。
2026/5/17:添加对 n 次剩余(模素数版)的讲解。
快速傅里叶变换 FFT
模板题
梦开始的地方。
点表示法
首先,我们要了解到多项式的性质:
用任意 n+1 个不同点,均可唯一确定一个 n 次多项式 A\left(x\right)=a_0+a_1x+a_2x^2+\cdots+a_nx^n 。
:::info[证明]
我们令这 n+1 个点对为 \left(x_1,y_1\right)\sim \left(x_{n+1},y_{n+1}\right) ,则有:
\begin{cases}
a_0+a_1x_1+...+a_nx_1^n=y_1\\
a_0+a_1x_2+...+a_nx_2^n=y_2\\
...\\
a_0+a_1x_n+...+a_nx_n^n=y_n\\
a_0+a_1x_{n+1}+...+a_nx_{n+1}^n=y_{n+1}\\
\end{cases}
注意到这是一个 n+1 元一次多项式。
那么它有唯一解,即等价于它的系数矩阵的秩是满秩的,即等价于它的系数矩阵的行列式是不等于 0 的。
这里给出它的行列式:
\begin{pmatrix}
1 & x_1 & x_1^2 &...&x_1^n\\
1 & x_2 & x_2^2 &...&x_2^n\\
...\\
1 & x_{n+1} & x_{n+1}^2 &...&x_{n+1}^n\\
\end{pmatrix}
这个行列式的值等于 n+1 阶范德蒙行列式的值。
即等于
\prod_{1\le i<j\le n+1}\left(x_i-x_j\right)
因为 x_1\sim x_{n+1} 的值互不相同,则上述式子不等于 0 。
命题得证。
:::
由于这个性质,我们即可将一个多项式的系数表示法转换为点表示法。
而 FFT 是取 n 个特殊点进行对原多项式的变换。
单位根
我们引入复数。
:::info[复数的加法]
对于两个复数 a+bi 以及 c+di ,则有
\left(a+bi\right)+\left(c+di\right)=\left(a+c\right)+\left(b+d\right)i
:::
:::info[复数的乘法]
对于两个复数 a+bi 以及 c+di ,则有
\left(a+bi\right)\times\left(c+di\right)=\left(ac-bd\right)+\left(ad+bc\right)i
:::
再引入复数意义下的 n 次单位根。
复数意义下的 n 次单位根即在复平面上把单位圆 n 等分。我们把它记作 \omega_n^k ,表示在复平面上把单位圆 n 等分中,从 \left(1,0\right) 开始,逆时针转 k 个 \frac{1}{n} 圈落在单位圆上的那个点。
则它有以下性质:
\omega_n^i \neq \omega_n^j,\forall i\neq j
\omega_n^k=\cos \frac{2k\pi}{n}+i\sin \frac{2k\pi}{n}
\omega_n^0=\omega_n^n=1
\omega_{2n}^{2k}=\omega_n^k
\omega_{n}^{k+\frac{n}{2}}=-\omega_n^k
正变换
这一步是将系数表示法转换为点表示法 ,即我们将原先的 n-1 次多项式 A\left(x\right) 转换为由 n 个特殊点对 \left(\omega_n^k,A\left(\omega_n^k\right)\right) 表示。则目标变为如何快速求解 A(\omega_n^k) 。
我们先将 A\left(x\right)=a_0+a_1x+a_2x^2+...a_nx^n 进行按照系数下标的奇偶性进行分类。
则有
A_1\left(x\right)=a_0+a_2x+a_4x^2+...+a_{n-2}x^{\frac{n}{2}-1}\\
A_2\left(x\right)=a_1+a_3x+a_5x^2+...+a_{n-1}x^{\frac{n}{2}-1}
容易发现,A\left(x\right)=A_1\left(x^2\right)+xA_2\left(x^2\right) 。
当 k \isin [0,\frac{n}{2}-1] 时,则有 A\left(\omega_n^k\right)=A_1\left(\omega_n^{2k}\right)+\omega_n^kA_2\left(\omega_n^{2k}\right) 。
那么
\begin{aligned}
A\left(\omega_n^k\right)&=A_1\left(\omega_n^{2k}\right)+\omega_n^kA_2\left(\omega_n^{2k}\right)\\
&=A_1\left(\omega_{\frac{n}{2}}^k\right)+\omega_n^kA_2\left(\omega_{\frac{n}{2}}^k\right)
\end{aligned}
当 k \isin [\frac{n}{2},n-1] 时,则有 A\left(\omega_n^{k+\frac{n}{2}}\right)=A_1\left(\omega_n^{2k+n}\right)+\omega_n^kA_2\left(\omega_n^{2k+n}\right) 。
那么
\begin{aligned}
A\left(\omega_n^{k+\frac{n}{2}}\right)&=A_1\left(\omega_n^{2k+n}\right)+\omega_n^{k+\frac{n}{2}}A_2\left(\omega_n^{2k+n}\right)\\
&=A_1\left(\omega_n^{2k}\right)+\omega_n^{k+\frac{n}{2}}A_2\left(\omega_n^{2k}\right)\\
&=A_1\left(\omega_{\frac{n}{2}}^k\right)-\omega_n^kA_2\left(\omega_{\frac{n}{2}}^k\right)
\end{aligned}
容易发现,当我们求 A\left(\omega_n^k\right) 时,可以运用分治法先求出 A_1\left(\omega_{\frac{n}{2}}^k\right) 和 A_2\left(\omega_{\frac{n}{2}}^k\right) ,然后分讨 k 的范围,套用以上式子求解即可。
由于使用分治,则时间复杂度优化到 O\left(n\log n\right) 。
逆变换
这一步是将点表示法转换为系数表示法 。
为了方便书写,我们将原先的点对 \left(\omega_n^k,A\left(\omega_n^k\right)\right) 写为 \left(x_k,y_k\right) 。
对于由这 n+1 个特殊点对表示的多项式的每一项的系数 c_k ,有
c_k=\frac{1}{n}\sum_{i=0}^{n-1}y_i\left(\omega_{n}^{-k}\right)^i。
:::info[证明]
考虑倒推
\begin{aligned}
c_k&=\frac{1}{n}\sum_{i=0}^{n-1}y_i\left(\omega_{n}^{-k}\right)^i\\
&=\frac{1}{n}\sum_{i=0}^{n-1}\left(\sum_{j=0}^{n-1}a_j\left(\omega_{n}^{i}\right)^j\right)\left(\omega_{n}^{-k}\right)^i\\
&=\frac{1}{n}\sum_{i=0}^{n-1}\left(\sum_{j=0}^{n-1}a_j\left(\omega_n^{j-k}\right)^i\right)\\
&=\frac{1}{n}\sum_{j=0}^{n-1}a_j\left(\sum_{i=0}^{n-1}\left(\omega_{n}^{j-k}\right)^i\right)
\end{aligned}
考虑构造一个多项式 S\left(x\right)=1+x+x^2+..+x^{n-1} 。
当 k\neq0 时:
S\left(\omega_{n}^{k}\right)=1+\omega_{n}^{k}+\omega_n^{2k}+...+\omega_{n}^{\left(n-1\right)k}
则
\omega_{n}^kS\left(\omega_{n}^{k}\right)=\omega_{n}^k+\omega_n^{2k}+...\omega_{n}^0
则
\left(1-\omega_{n}^k\right)S\left(\omega_{n}^{k}\right)=0
由于 1-\omega_{n}^k\neq0 ,则 S\left(\omega_n^k\right)=0 。
当 k=0 时,S\left(\omega_n^k\right)=S\left(1\right)=n 。
则
\begin{aligned}
&\frac{1}{n}\sum_{j=0}^{n-1}a_j\left(\sum_{i=0}^{n-1}\left(\omega_{n}^{j-k}\right)^i\right)\\
&=\frac{1}{n}\times na_k\\
&=a_k
\end{aligned}
因此 c_k = a_k ,即我们通过逆变换恢复了原系数。
命题得证。
:::
有了以上结论,我们考虑如何快速求解。
令 b_k=\sum_{i=0}^{n-1}y_i\left(\omega_{n}^{-k}\right)^i 以及一个多项式 B\left(x\right)=y_0+y_1x+\cdots+y_{n-1}x^{n-1} 。
则 b_k=B\left(\omega_{n}^{-k}\right) ,容易发现求解这个式子可以套用正变换的方法,这里不多加赘述。
最后的 c_k=\frac{1}{n}b_k 。
二进制翻转
经过上面的推导,我们知道 FFT 需要使用分治,即上文所说的按照系数下标的奇偶性进行分类。
为了实现迭代,我们需要考虑如何正确地实现分类。
先看下面的例子:
最上面的数字表示最开始位置下标的二进制,最下面的数字表示最后位置下标的二进制。
容易发现,从最开始到分治结束,二进制进行了翻转。
那么对于第 i 个二进制数翻转后的结果 \text{rev}_i ,有如下递推式:
\text{rev}_i = \left\lfloor \frac{\text{rev}_{\lfloor \frac{i}{2} \rfloor}}{2} \right\rfloor + \left(i \bmod 2\right)\times 2^{\text{bit}-1}
其中,\text{bit} 表示二进制位数。
代码实现如下:
while((1<<bit)<n+m+1) bit++;
tot=1<<bit;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
蝶形运算
在 FFT 的分治合并过程中,我们利用以下两个公式同时计算出两个点的值:
A\left(\omega_n^k\right)=A_1\left(\omega_{\frac{n}{2}}^k\right)+\omega_n^kA_2\left(\omega_{\frac{n}{2}}^k\right)\\
A\left(\omega_n^{k+\frac{n}{2}}\right)=A_1\left(\omega_{\frac{n}{2}}^k\right)-\omega_n^kA_2\left(\omega_{\frac{n}{2}}^k\right)
其中 k = 0, 1, \dots, \frac{n}{2}-1 。这两个公式构成一次“蝶形运算”:
入:A_1 和 A_2 在相同位置 k 上的值,以及旋转因子 \omega_n^k 。
出:两个新值 A\left(\omega_n^k\right) 和 A\left(\omega_n^{k+\frac{n}{2}}\right) 。
而在迭代实现中,我们可以按照长度从小到大逐层合并。对于当前层长度 \text{mid} (即子问题长度),将整个序列划分为若干个长度为 2\times \text{mid} 的块,在每个块内执行 \text{mid} 次蝶形运算。实现如下:
for(int j=0;j<mid;j++,wk=wk*w1){
auto x=a[i+j],y=wk*a[i+j+mid];
a[i+j]=x+y,a[i+j+mid]=x-y;
}
配合上面的二进制翻转操作,我们可以成功地将 FFT 从递归转为迭代。
:::success[Code]
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=3e6+10;
const double pi=acos(-1);
int n,m;
struct node{
double x,y;
node operator+ (const node& t)const{
return {x+t.x,y+t.y};
}
node operator- (const node& t)const{
return {x-t.x,y-t.y};
}
node operator* (const node& t)const{
return {x*t.x-y*t.y,x*t.y+y*t.x};
}
}a[maxn],b[maxn];
int rev[maxn],bit,tot;
void FFT(node a[],int inv){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
auto w1=node({cos(pi/mid),inv*sin(pi/mid)});
for(int i=0;i<tot;i+=mid*2){
auto wk=node({1,0});
for(int j=0;j<mid;j++,wk=wk*w1){
auto x=a[i+j],y=wk*a[i+j+mid];
a[i+j]=x+y,a[i+j+mid]=x-y;
}
}
}
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n>>m;
for(int i=0;i<=n;i++) cin>>a[i].x;
for(int i=0;i<=m;i++) cin>>b[i].x;
while((1<<bit)<n+m+1) bit++;
tot=1<<bit;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
FFT(a,1),FFT(b,1);
for(int i=0;i<tot;i++) a[i]=a[i]*b[i];
FFT(a,-1);
for(int i=0;i<=n+m;i++) cout<<(int)round(a[i].x/tot)<<" ";
return 0;
}
:::
快速数论变换 NTT
由于你谷并没有专门的 NTT 模版,这里用 FFT 的模板代替。(原因)
模板题
为了能够使得取模后不爆精度,考虑找到某个东西能满足单位根的性质来代替单位根。
那么,原根 就登上了多项式的舞台。
给出两个正整数数 m,p ,满足 m 和 p 互质,且 p>1 。
若有 n 满足 m^n\equiv 1\pmod p 且 n 最小,则称 n 为 m 模 p 意义下的阶,记作 \delta_p(m)=n 。
若 p 为素数,g 是模 p 意义下的原根,则满足 \gcd(g,p)=1 且 \delta_p(g)=p-1 ,且 \{g^{i}\bmod p\mid i=0,1,\dots p-2\}=\{1,2,\dots,p-1\} 。
令 \omega = g^{\frac{(p-1)}{n}} \bmod p ,其中 n \mid p-1 。
则
\omega^n = g^{p-1} \equiv 1 \pmod{p}
且若 \omega^k \equiv 1 ,则 g^{\frac{(p-1)k}{n}} \equiv 1 ,故 p-1 \mid \frac{(p-1)k}{n} ,即 n \mid k 。
因此 \operatorname{ord}_p(\omega) = n ,即 \omega 是模 p 下的 n 次本原单位根。
所以 \omega_n \equiv g^{\frac{p-1}{n}}\pmod p 。
即我们可将单位根替换为原根。
:::success[Code]
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=3e6+10;
const int P=998244353,G=3;
int n,m;
int a[maxn],b[maxn];
int qpow(int x,int k){
int base=1;
while(k){
if(k&1) base=(base*x)%P;
x=(x*x)%P;
k>>=1;
}
return base%P;
}
const int Gi=qpow(G,P-2);
int rev[maxn],bit,tot;
void NTT(int a[],int type){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
int wi;
if(type==1) wi=qpow(G,(P-1)/(mid<<1));
else wi=qpow(Gi,(P-1)/(mid<<1));
for(int i=0;i<tot;i+=mid*2){
int wk=1;
for(int j=0;j<mid;j++,wk=wk*wi%P){
int x=a[i+j],y=wk*a[i+j+mid]%P;
a[i+j]=(x+y)%P,a[i+j+mid]=(x-y+P)%P;
}
}
}
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n>>m;
for(int i=0;i<=n;i++) cin>>a[i],a[i]%=P;
for(int i=0;i<=m;i++) cin>>b[i],b[i]%=P;
while((1<<bit)<n+m+1) bit++;
tot=1<<bit;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
NTT(a,1),NTT(b,1);
for(int i=0;i<tot;i++) a[i]=(a[i]*b[i])%P;
NTT(a,-1);
int inv=qpow(tot,P-2);
for(int i=0;i<=n+m;i++) cout<<a[i]*inv%P<<" ";
return 0;
}
:::
值得注意的是,普通 NTT 的模数的必须形如 a\times2^k+1 ,我们称之为友好模数。
任意模数快速傅里叶变换 MTT
模板题
拆系数 FFT
由于取模,FFT 会爆精度。
而 NTT 又要求模数是友好模数,在这道题上不适用。
考虑改良原来的 FTT。
我们将每个系数拆分为低 15 位和高 15 位。
即
A(x)=A_1(x)+2^{15}\times A_2(x)\\
B(x)=B_1(x)+2^{15}\times B_2(x)
那么有
A\times B=A_1B_1+2^{15}\times(A_1B_2+A_2B_1)+2^{30}\times A_2B_2
然后做 7 遍 FFT 即可。
:::success[Code]
#include<bits/stdc++.h>
#define int long long
#define double long double
#define endl '\n'
using namespace std;
const int maxn=1e6+10;
const double pi=acos(-1);
int P;
int n,m;
int a[maxn],b[maxn],c[maxn];
struct node {
double x,y;
node operator+(const node&t)const {
return{x+t.x,y+t.y};
}
node operator-(const node&t)const {
return{x-t.x,y-t.y};
}
node operator*(const node&t)const {
return{x*t.x-y*t.y,x*t.y+y*t.x};
}
};
int rev[maxn],bit,tot;
node buf1[maxn],buf2[maxn],buf3[maxn],buf4[maxn];
void FFT(node a[],int inv) {
for(int i=0;i<tot;i++) {
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1) {
node w1={cos(pi/mid),inv*sin(pi/mid)};
for(int i=0;i<tot;i+=mid*2) {
node wk={1,0};
for(int j=0;j<mid;j++,wk=wk*w1) {
node x=a[i+j],y=wk*a[i+j+mid];
a[i+j]=x+y;
a[i+j+mid]=x-y;
}
}
}
if(inv==-1) {
for(int i=0;i<tot;i++) {
a[i].x/=tot;
a[i].y/=tot;
}
}
}
int get_val(double x){
return ((long long)(round(x)))%P;
}
void MTT(int c[],int len){
bit=0;
while((1<<bit)<len) bit++;
tot=1<<bit;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<tot;i++) buf1[i]=buf2[i]=buf3[i]=buf4[i]={0,0};
for(int i=0;i<n;i++){
buf1[i]={(double)(a[i]&32767),0};
buf2[i]={(double)(a[i]>>15),0};
}
for(int i=0;i<m;i++){
buf3[i]={(double)(b[i]&32767),0};
buf4[i]={(double)(b[i]>>15),0};
}
FFT(buf1,1); FFT(buf2,1); FFT(buf3,1); FFT(buf4,1);
for(int i=0;i<tot;i++){
node tmp1=buf1[i]*buf3[i],tmp2=buf1[i]*buf4[i],tmp3=buf2[i]*buf3[i],tmp4=buf2[i]*buf4[i];
buf1[i]=tmp1;
buf2[i]=tmp2+tmp3;
buf3[i]=tmp4;
}
FFT(buf1,-1); FFT(buf2,-1); FFT(buf3,-1);
for(int i=0;i<len;i++){
int low=get_val(buf1[i].x),mid=get_val(buf2[i].x),high=get_val(buf3[i].x);
c[i]=(low+(mid<<15)+(high<<30))%P;
}
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n>>m>>P;
n++; m++;
for(int i=0;i<n;i++) cin>>a[i],a[i]%=P;
for(int i=0;i<m;i++) cin>>b[i],b[i]%=P;
MTT(c,n+m-1);
for(int i=0;i<n+m-1;i++) cout<<c[i]<<" ";
return 0;
}
:::
三模 NTT
:::info[exgcd 的讲解]
exgcd 是用来求方程
ax+by=\gcd(a,b)
的一组整数解的。
我们假设已经知道了方程
bx+\left(a\bmod b\right)y=\gcd(b,a\bmod b)=\gcd(a,b)
的一组整数解 (x',y') 。
又有
a\bmod b=a-\lfloor\frac{a}{b}\rfloor\times b
我们令 q=\lfloor\frac{a}{b}\rfloor ,则有
bx'+\left(a-qb\right)y'=\gcd(a,b)
整理,得
ay'+b(x'-qy')=\gcd(a,b)
对比我们想要求解的方程,我们惊喜地发现:
\begin{cases}
x=y'\\y=x'-qy'
\end{cases}
在代码中递归求解即可。
int exgcd(int a,int b,int &x,int &y){
if(!b){
x=1; y=0;
return a;
}
int gcd=exgcd(b,a%b,x,y);
int t=x;
x=y; y=t-a/b*y;
return gcd;
}
:::
:::info[CRT 的讲解]
模板题
CRT 用来求解如下同余方程组:
\begin{cases}
x\equiv b_1\pmod{a_1}\\
x\equiv b_2\pmod{a_2}\\
\ldots\\
x\equiv b_n\pmod{a_n}
\end{cases}
考虑拆成如下几个方程组
\begin{cases}
x_1\equiv b_1\pmod{a_1}\\
x_1\equiv 0\pmod{a_2}\\
\ldots\\
x_1\equiv 0\pmod{a_n}
\end{cases},
\begin{cases}
x_2\equiv 0\pmod{a_1}\\
x_2\equiv b_2\pmod{a_2}\\
\ldots\\
x_2\equiv 0\pmod{a_n}
\end{cases},\ldots,
\begin{cases}
x_n\equiv 0\pmod{a_1}\\
x_n\equiv 0\pmod{a_2}\\
\ldots\\
x_n\equiv b_n\pmod{a_n}
\end{cases}
则 x=\sum_{i=1}^nx_i 。
然后令 y_i=\frac{x_i}{b_i} ,则有
\begin{cases}
y_1\equiv 1\pmod{a_1}\\
y_1\equiv 0\pmod{a_2}\\
\ldots\\
y_1\equiv 0\pmod{a_n}
\end{cases},
\begin{cases}
y_2\equiv 0\pmod{a_1}\\
y_2\equiv 1\pmod{a_2}\\
\ldots\\
y_2\equiv 0\pmod{a_n}
\end{cases},\ldots,
\begin{cases}
y_n\equiv 0\pmod{a_1}\\
y_n\equiv 0\pmod{a_2}\\
\ldots\\
y_n\equiv 1\pmod{a_n}
\end{cases}
那么想让 y_i\equiv0\pmod{a_j},j\neq i ,y_i 肯定为 \prod_{j\neq i}a_j 的倍数,则令 y_i=k_i\prod_{j\neq i}a_j ,则有
k_i\prod_{j\neq i}a_j\equiv1\pmod{a_i}
注意到这个东西用逆元求解即可。
但如何算出最小解呢?下面的话摘自 wsy_I 大神的学习笔记:
注意观察一个性质:假如这个 x 加上或者减去 \prod^{n}_{i=1}a_i ,x 对于所有的 a_i 取模之后的结果是相同的。如果加上或者减去的值不同于 \prod^{n}_{i=1}a_i ,那么必定有某些取模之后的结果不同。于是构造最小解可以直接对于 \prod^{n}_{i=1}a_i 取模。
温馨提示:模板题会爆 long long。
:::success[Code]
#include<bits/stdc++.h>
#define int __int128
#define endl '\n'
using namespace std;
const int maxn=10+5;
void write_(int x){
if(!x)return;
write_(x/10);
putchar(x%10+'0');
}
void write(int x){
if(!x)putchar('0');
else{
if(x<0) putchar('-'),x=-x;
write_(x);
}
}
inline int read(){
int x=0,y=0;char c;
while(!isdigit(c))y|=c=='-',c=getchar();
while(isdigit(c))x=(x<<3)+(x<<1)+(c^48),c=getchar();
return y?-x:x;
}
void exgcd(int a,int b,int &x,int &y){
if(!b){
x=1; y=0;
return;
}
exgcd(b,a%b,x,y);
int t=x;
x=y; y=t-a/b*y;
return;
}
int n,a[maxn],b[maxn];
int ans,mod=1;
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
n=read();
for(int i=1;i<=n;i++) a[i]=read(),b[i]=read(),mod*=a[i];
for(int i=1;i<=n;i++){
int mul=1;
for(int j=1;j<=n;j++){
if(i==j) continue;
mul*=a[j];
}
int kx,ky;
exgcd(mul,a[i],kx,ky);
kx=(kx%a[i]+a[i])%a[i];
int y=kx*mul,x=y*b[i];
ans=(ans+x)%mod;
}
write(ans);
return 0;
}
:::
:::info[原理]{open}
三模 NTT 是通过选取三个友好模数 m_1,m_2,m_3 ,再用 NTT 计算 F(x)G(x) 在模 m_i 意义下的答案,最后用 CRT 恢复精确系数。
:::
这里选取的三个模数为:
\begin{aligned}
&m_1=998244353=119\times2^{23}+1\\
&m_2=469762049=7\times2^{26}+1\\
&m_3=1004535809=479\times2^{21}+1
\end{aligned}
优点是它们的原根都为 3 。
接下来考虑如何合并取模后的答案。我们使用 CRT。
对于同余方程组
\begin{cases}
x\equiv r_1\pmod{m_1}\\
x\equiv r_2\pmod{m_2}
\end{cases}
其中 m_1,m_2 互质。
则有
x=r_1+m_1\times\left(\left(t\times\left(r_2-r_1\right)\right)\bmod m_2\right)
其中 t 为 m_1 模 m_2 意义下的逆元。
按照这样将 m_1,m_2,m_3 合并,我们得到的 x 就是该系数的真实整数值(在 0\le x<m_1\times m_2\times m_3 范围内)。
最后的答案输出真实系数值再对题目要求的 p 取模即可。
:::success[Code]
#include<bits/stdc++.h>
#define int __int128
#define endl '\n'
using namespace std;
const int maxn=3e6+10;
const int mod1=998244353,mod2=(1<<26)*7+1,mod3=(1<<21)*479+1,G=3;
int a[maxn],b[maxn];
void write_(int x){
if(!x)return;
write_(x/10);
putchar(x%10+'0');
}
void write(int x){
if(!x)putchar('0');
else{
if(x<0) putchar('-'),x=-x;
write_(x);
}
}
inline int read(){
int x=0,y=0;char c=getchar();
while(!isdigit(c))y|=c=='-',c=getchar();
while(isdigit(c))x=(x<<3)+(x<<1)+(c^48),c=getchar();
return y?-x:x;
}
int qpow(int x,int k,int mod){
int base=1;
while(k){
if(k&1) base=(base*x)%mod;
x=(x*x)%mod;
k>>=1;
}
return base%mod;
}
void NTT(int a[],int rev[],int type,int tot,int mod){
int Gi=qpow(G,mod-2,mod);
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
int wi;
if(type==1) wi=qpow(G,(mod-1)/(mid<<1),mod);
else wi=qpow(Gi,(mod-1)/(mid<<1),mod);
for(int i=0;i<tot;i+=mid*2){
int wk=1;
for(int j=0;j<mid;j++,wk=wk*wi%mod){
int x=a[i+j],y=wk*a[i+j+mid]%mod;
a[i+j]=(x+y)%mod,a[i+j+mid]=(x-y+mod)%mod;
}
}
}
}
void exgcd(int a,int b,int &x,int &y){
if(!b){
x=1; y=0;
return;
}
exgcd(b,a%b,x,y);
int t=x;
x=y; y=t-a/b*y;
return;
}
int rev[maxn],bit,tot;
int n,m,p;
void MUL(int f[],int g[],int len1,int len2,int mod){
memset(rev,0,sizeof(rev)); bit=tot=0;
while((1<<bit)<len1+len2+1) bit++;
tot=1<<bit;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
NTT(f,rev,1,tot,mod),NTT(g,rev,1,tot,mod);
for(int i=0;i<tot;i++) f[i]=(f[i]*g[i])%mod;
NTT(f,rev,-1,tot,mod);
int inv=qpow(tot,mod-2,mod);
for(int i=0;i<=len1+len2;i++) f[i]=f[i]*inv%mod;
}
int CRT(int r1,int m1,int r2,int m2){
int k,t;
exgcd(m1,m2,k,t);
k=(k%m2+m2)%m2;
int d=((r2-r1)%m2+m2)%m2;
return r1+m1*((d*k)%m2);
}
int tmpa1[maxn],tmpa2[maxn],tmpa3[maxn];
int tmpb1[maxn],tmpb2[maxn],tmpb3[maxn];
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
n=read(),m=read(),p=read();
for(int i=0;i<=n;i++) a[i]=read(),tmpa1[i]=tmpa2[i]=tmpa3[i]=a[i];
for(int i=0;i<=m;i++) b[i]=read(),tmpb1[i]=tmpb2[i]=tmpb3[i]=b[i];
MUL(tmpa1,tmpb1,n,m,mod1);
MUL(tmpa2,tmpb2,n,m,mod2);
MUL(tmpa3,tmpb3,n,m,mod3);
int len=n+m+1;
for(int i=0;i<len;i++){
int r1=tmpa1[i],r2=tmpa2[i],r3=tmpa3[i];
int x12=CRT(r1,mod1,r2,mod2);
int mod12=mod1*mod2;
int x=CRT(x12,mod12,r3,mod3);
write(x%p),putchar(' ');
}
return 0;
}
:::
温馨提示:计算过程中会爆 long long。
快速沃尔什变换 FWT
模板题
核心思想
寻找一种可逆线性变换 T ,使得
T(C)=T(A)\times T(B)
然后通过逆变换得到 C=T^{-1}\left(T(A)\times T(B)\right) 。
按位或
即求解 h_i=\sum_{k\oplus j=i}f_jg_k ,其中 \oplus 表示按位或。
正变换
对于任意一个二进制数 S ,我们令
F_S=\sum_{T\subseteq S}f_T\\
G_S=\sum_{T\subseteq S}g_T\\
H_S=\sum_{T\subseteq S}h_T
则有
F_S\times G_S=\left(\sum_{A\subseteq S}f_A\right)\left(\sum_{B\subseteq S}g_B\right)
我们注意到展开后,f_Ag_B 会出现在某个集合 C=A\cup B 中,而 C\subseteq S ,则有
F_S\times G_S=\sum_{C\subseteq S}\left(\sum_{A\cup B=C}f_Ag_B\right)
因为 \cup 和按位或是等价的,所以括号里求的就是 h_C ,即
F_S\times G_S=\sum_{C\subseteq S}h_C\\
F_S\times G_S=H_S
至此,我们便得到了
H_S=F_S\times G_S
逆变换
我们观察
H_S=\sum_{T\subseteq S}h_T
并不是容易地发现这是个高维前缀和,还原用容斥即可,即
h_S=\sum_{T\subseteq S}\left(-1\right)^{|S|-|T|}H_T
其中,|S| 表示二进制数 S 中 1 的个数。
而在代码实现中,我们对每个二进制位(从低位到高位),如果该位是 1 ,则让它减去去掉这一位的那个值。这样一层层减下去,就还原了。
按位与
即求解 h_i=\sum_{k\oplus j=i}f_jg_k ,其中 \oplus 表示按位与。
正变换
和按位或类似,把 \subseteq 换成 \supseteq 即可。
逆变换
直接给还原式子吧,因为和上面的类似。
h_S=\sum_{T\supseteq S}\left(-1\right)^{|T|-|S|}H_T
而在代码实现中,我们对每个二进制位(从低位到高位),如果该位是 1 ,则让左边的值减去右边的值(注意方向与按位或相反)。
按位异或
即求解 h_i=\sum_{k\oplus j=i}f_jg_k ,其中 \oplus 表示按位异或。
直接给结论吧,想看详细推导的可以看此文实在看不懂。
正变换
\hat{f}_S=\sum_{T}f_T\times\left(-1\right)^{|S\cap T|}
代码实现步骤如下:
把数组分成两个一组(相邻两个数一组)。
对于每一组 (x,y) ,计算两个新数:
用这两个新数替换原来的数。
逆变换
f_S=\frac{1}{2^n}\sum_{T}\hat{f}_T\times\left(-1\right)^{|S\cap T|}
代码实现步骤如下:
把数组分成两个一组(相邻两个数一组)。
对于每一组 (x,y) ,计算两个新数:
新左边的数:\frac{x+y}{2} 。
新右边的数:\frac{x-y}{2} 。
用这两个新数替换原来的数。
其实就比正变换多除了个 2 。
:::success[Code]
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=1<<17;
const int mod=998244353,inv2=(mod+1)/2;
int n;
int tot;
int a[maxn],b[maxn];
int tmp_a[maxn],tmp_b[maxn];
void FWT_or(int tmp[],int inv,int len){
for(int mid=1;mid<len;mid<<=1){
for(int i=0;i<len;i+=mid*2){
for(int j=0;j<mid;j++){
int x=tmp[i+j],y=tmp[i+j+mid];
if(inv==1) tmp[i+j+mid]=(x+y)%mod;
else tmp[i+j+mid]=(y-x+mod)%mod;
}
}
}
}
void FWT_and(int tmp[],int inv,int len){
for(int mid=1;mid<len;mid<<=1){
for(int i=0;i<len;i+=mid*2){
for(int j=0;j<mid;j++){
int x=tmp[i+j],y=tmp[i+j+mid];
if(inv==1) tmp[i+j]=(x+y)%mod;
else tmp[i+j]=(x-y+mod)%mod;
}
}
}
}
void FWT_xor(int tmp[],int inv,int len){
for(int mid=1;mid<len;mid<<=1){
for(int i=0;i<len;i+=mid*2){
for(int j=0;j<mid;j++){
int x=tmp[i+j],y=tmp[i+j+mid];
tmp[i+j]=(x+y)%mod; tmp[i+j+mid]=(x-y+mod)%mod;
if(inv==-1){
tmp[i+j]=tmp[i+j]*inv2%mod;
tmp[i+j+mid]=tmp[i+j+mid]*inv2%mod;
}
}
}
}
}
void init(){
memcpy(tmp_a,a,sizeof(tmp_a));
memcpy(tmp_b,b,sizeof(tmp_b));
}
void OR(int tmp1[],int tmp2[],int len){
FWT_or(tmp1,1,len); FWT_or(tmp2,1,len);
for(int i=0;i<len;i++) tmp1[i]=tmp1[i]*tmp2[i]%mod;
FWT_or(tmp1,-1,len);
for(int i=0;i<len;i++) cout<<tmp1[i]<<" ";
cout<<endl;
}
void AND(int tmp1[],int tmp2[],int len){
FWT_and(tmp1,1,len); FWT_and(tmp2,1,len);
for(int i=0;i<len;i++) tmp1[i]=tmp1[i]*tmp2[i]%mod;
FWT_and(tmp1,-1,len);
for(int i=0;i<len;i++) cout<<tmp1[i]<<" ";
cout<<endl;
}
void XOR(int tmp1[],int tmp2[],int len){
FWT_xor(tmp1,1,len); FWT_xor(tmp2,1,len);
for(int i=0;i<len;i++) tmp1[i]=tmp1[i]*tmp2[i]%mod;
FWT_xor(tmp1,-1,len);
for(int i=0;i<len;i++) cout<<tmp1[i]<<" ";
cout<<endl;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n; tot=1<<n;
for(int i=0;i<tot;i++) cin>>a[i];
for(int i=0;i<tot;i++) cin>>b[i];
init(); OR(tmp_a,tmp_b,tot);
init(); AND(tmp_a,tmp_b,tot);
init(); XOR(tmp_a,tmp_b,tot);
return 0;
}
:::
多项式乘法逆
模板题
这里用的是牛顿迭代法,牛顿迭代一般用于求一个函数的根(即与 x 轴的交点)。
牛顿迭代公式:
x_{n+1}=x_n-\frac{f(x_n)}{f'(x_n)}
其中,f'(x) 是 f(x) 的导数。
可以理解为利用函数在某点的切线来逼近函数本身,并通过不断求解该切线与 x 轴的交点来迭代逼近方程的真实根。
回到本题,我们考虑移项。
F(x)-\frac{1}{G(x)}\equiv0\pmod{x^n}
然后令 H(t)=F(x)-\frac{1}{t} ,则有
H(G(x))\equiv0\pmod{x^n}
然后套牛顿迭代公式,得
G_{n+1}(x)=G_n(x)-\frac{H(G_{n}(x))}{H'(G_{n}(x))}
又由于 H(G_n(x))=F(x)-\frac{1}{G_n(x)} ,H'(G_n(x))=\frac{1}{G_n(x)^2} ,则有
G_{n+1}(x)=G_{n}(x)-\frac{F(x)-\frac{1}{G_n(x)}}{\frac{1}{G_n(x)^2}}
化简,得
G_{n+1}(x)=2G_n(x)-G_n(x)^2F(x)
然后递推做 NTT 即可。
:::success[Code]
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=3e6+10;
const int P=998244353,G=3;
int n;
int a[maxn],b[maxn],c[maxn];
int rev[maxn],bit,tot;
int qpow(int x,int k){
int base=1;
while(k){
if(k&1) base=base*x%P;
x=x*x%P;
k>>=1;
}
return base%P;
}
const int Gi=qpow(G,P-2);
void NTT(int a[],int type){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
int wi;
if(type==1) wi=qpow(G,(P-1)/(mid<<1));
else wi=qpow(Gi,(P-1)/(mid<<1));
for(int i=0;i<tot;i+=mid*2){
int wk=1;
for(int j=0;j<mid;j++,wk=wk*wi%P){
int x=a[i+j],y=wk*a[i+j+mid]%P;
a[i+j]=(x+y)%P;
a[i+j+mid]=(x-y+P)%P;
}
}
}
if(type==-1){
int inv_n=qpow(tot,P-2);
for(int i=0;i<tot;i++) a[i]=a[i]*inv_n%P;
}
}
void INV(int b[],int a[],int n){
if(n==1){
b[0]=qpow(a[0],P-2);
return;
}
INV(b,a,(n+1)>>1);
int len=1;
bit=0;
while(len<(n<<1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) c[i]=a[i]; for(int i=n;i<tot;i++) c[i]=0;
NTT(b,1); NTT(c,1);
for(int i=0;i<tot;i++) b[i]=(2-b[i]*c[i]%P+P)%P*b[i]%P;
NTT(b,-1);
for(int i=n;i<tot;i++) b[i]=0;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n;
for(int i=0;i<n;i++) cin>>a[i],a[i]%=P;
INV(b,a,n);
for(int i=0;i<n;i++) cout<<b[i]<<' ';
return 0;
}
:::
:::info[时间复杂度分析]
设 T(n) 为计算 n 次多项式逆元的总时间。则有:
T(n) = T\left(\left\lceil \frac{n}{2} \right\rceil\right) + O(n \log n)
展开递归:
T(n) = \sum_{k=0}^{\lfloor \log_2 n \rfloor} O\left( \frac{n}{2^k} \log \frac{n}{2^k} \right)
令 n_k = \frac{n}{2^k} ,则:
T(n) = O\left( \sum_{k=0}^{\lfloor \log n \rfloor} n_k \log n_k \right)
由于 n_k \log n_k = \frac{n}{2^k} (\log n - k) ,求和:
\sum_{k=0}^{\lfloor \log n \rfloor} \frac{n}{2^k} (\log n - k) = n \left( \log n \sum_{k=0}^{\lfloor \log n \rfloor} \frac{1}{2^k} - \sum_{k=0}^{\lfloor \log n \rfloor} \frac{k}{2^k} \right)
已知 \sum_{k=0}^{\infty} \frac{1}{2^k} = 2 ,\sum_{k=0}^{\infty} \frac{k}{2^k} = 2 ,因此括号内为 O(1) 。故:
T(n) = O(n \log n)
:::
任意模数多项式乘法逆
模板题
把 NTT 换成 MTT 即可。
这里写的是拆系数 FFT。
:::success[Code]
#include<bits/stdc++.h>
#define int long long
#define double long double
#define endl '\n'
using namespace std;
const int maxn=3e5+10;
const int P=1e9+7;
int n;
int a[maxn],b[maxn],c[maxn];
int rev[maxn],bit,tot;
int qpow(int x,int k){
int base=1;
while(k){
if(k&1) base=base*x%P;
x=x*x%P;
k>>=1;
}
return base%P;
}
const double pi=acos(-1);
struct node {
double x,y;
node operator+(const node&t)const {
return{x+t.x,y+t.y};
}
node operator-(const node&t)const {
return{x-t.x,y-t.y};
}
node operator*(const node&t)const {
return{x*t.x-y*t.y,x*t.y+y*t.x};
}
};
node buf1[maxn],buf2[maxn],buf3[maxn],buf4[maxn];
void FFT(node a[],int inv) {
for(int i=0;i<tot;i++) {
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1) {
node w1={cos(pi/mid),inv*sin(pi/mid)};
for(int i=0;i<tot;i+=mid*2) {
node wk={1,0};
for(int j=0;j<mid;j++,wk=wk*w1) {
node x=a[i+j],y=wk*a[i+j+mid];
a[i+j]=x+y;
a[i+j+mid]=x-y;
}
}
}
if(inv==-1) {
for(int i=0;i<tot;i++) {
a[i].x/=tot;
a[i].y/=tot;
}
}
}
int get_val(double x){
return (long long)(round(x))%P;
}
void MTT(int x[],int y[],int res[],int len){
bit=0;
while((1<<bit)<len) bit++;
tot=1<<bit;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<tot;i++) buf1[i]=buf2[i]=buf3[i]=buf4[i]={0,0};
for(int i=0;i<len;i++){
buf1[i]={(double)(x[i]&32767),0};
buf2[i]={(double)(x[i]>>15),0};
}
for(int i=0;i<len;i++){
buf3[i]={(double)(y[i]&32767),0};
buf4[i]={(double)(y[i]>>15),0};
}
FFT(buf1,1); FFT(buf2,1); FFT(buf3,1); FFT(buf4,1);
for(int i=0;i<tot;i++){
node tmp1=buf1[i]*buf3[i],tmp2=buf1[i]*buf4[i],tmp3=buf2[i]*buf3[i],tmp4=buf2[i]*buf4[i];
buf1[i]=tmp1;
buf2[i]=tmp2+tmp3;
buf3[i]=tmp4;
}
FFT(buf1,-1); FFT(buf2,-1); FFT(buf3,-1);
for(int i=0;i<len;i++){
int low=get_val(buf1[i].x),mid=get_val(buf2[i].x),high=get_val(buf3[i].x);
res[i]=(low+(mid<<15)+(high<<30))%P;
if(res[i]<0) res[i]+=P;
}
}
void INV(int b[],int a[],int n){
if(n==1){
b[0]=qpow(a[0],P-2);
return;
}
INV(b,a,(n+1)>>1);
int len=1;
while(len<(n<<1)) len<<=1;
static int tmp[maxn];
for(int i=0;i<n;i++) tmp[i]=a[i];
for(int i=n;i<len;i++) tmp[i]=0;
for(int i=n;i<len;i++) b[i]=0;
MTT(b,b,c,len); MTT(c,tmp,c,len);
for(int i=0;i<n;i++) b[i]=(2*b[i]%P-c[i]+P)%P;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n;
for(int i=0;i<n;i++) cin>>a[i],a[i]%=P;
INV(b,a,n);
for(int i=0;i<n;i++) cout<<b[i]<<' ';
return 0;
}
:::
多项式除法
模板题
要用一个很巧的变形。
F(x)=G(x)\times Q(x)+R(x)\\
F(\frac{1}{x})=G(\frac{1}{x})\times Q(\frac{1}{x})+R(\frac{1}{x})\\
x^nF(\frac{1}{x})=x^{n-m}G(\frac{1}{x})\times x^{m}Q(\frac{1}{x})+x^{n-m+1}\times x^{m-1}R(\frac{1}{x})
然后令 F_R(x)=x^n F(\frac{1}{x}) ,其中 n 为 F(x) 的次数。则有
F_R(x)=G_R(x)\times Q_R(x)+x^{n-m+1}\times R_R(x)\\
F_R(x)\equiv G_R(x)\times Q_R(x)\pmod{x^{n-m+1}}\\
Q_R(x)\equiv \frac{F_R(x)}{G_R(x)}\pmod{x^{n-m+1}}
这个式子用多项式求逆即可,问题来到如何根据 F(x) 求 F_R(x) 。
其实容易得到,F(x) 的第 i 位的系数与 F_R(x) 的第 n-i 位系数相同。
$$
R(x)=F(x)-G(x)\times Q(x)
$$
:::success[Code]
```
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=6e6+10;
const int P=998244353,G=3;
int n,m;
int f[maxn],g[maxn],tmpf[maxn];
int fr[maxn],gr[maxn];
int q[maxn],qr[maxn];
int c[maxn],d[maxn],e[maxn],tmp[maxn];
int rev[maxn],bit,tot;
int inv_gr[maxn];
int qpow(int x,int k){
int base=1;
while(k){
if(k&1) base=base*x%P;
x=x*x%P;
k>>=1;
}
return base%P;
}
const int Gi=qpow(G,P-2);
void NTT(int a[],int type){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
int wi;
if(type==1) wi=qpow(G,(P-1)/(mid<<1));
else wi=qpow(Gi,(P-1)/(mid<<1));
for(int i=0;i<tot;i+=mid*2){
int wk=1;
for(int j=0;j<mid;j++,wk=wk*wi%P){
int x=a[i+j],y=wk*a[i+j+mid]%P;
a[i+j]=(x+y)%P;
a[i+j+mid]=(x-y+P)%P;
}
}
}
if(type==-1){
int inv_n=qpow(tot,P-2);
for(int i=0;i<tot;i++) a[i]=a[i]*inv_n%P;
}
}
void INV(int b[],int a[],int n){
if(n==1){
b[0]=qpow(a[0],P-2);
return;
}
INV(b,a,(n+1)>>1);
int len=1;
bit=0;
while(len<(n<<1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) c[i]=a[i];
for(int i=n;i<tot;i++) c[i]=0;
for(int i=(n+1)/2;i<tot;i++) b[i]=0;
NTT(b,1); NTT(c,1);
for(int i=0;i<tot;i++) b[i]=(2-b[i]*c[i]%P+P)%P*b[i]%P;
NTT(b,-1);
for(int i=n;i<tot;i++) b[i]=0;
}
void MUL(int a[],int b[],int c[],int n,int m){
int len=1;
bit=0;
while(len<(n+m-1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) d[i]=a[i];
for(int i=n;i<tot;i++) d[i]=0;
for(int i=0;i<m;i++) e[i]=b[i];
for(int i=m;i<tot;i++) e[i]=0;
NTT(d,1); NTT(e,1);
for(int i=0;i<tot;i++) c[i]=d[i]*e[i]%P;
NTT(c,-1);
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n>>m;
for(int i=0;i<=n;i++) cin>>f[i],f[i]%=P;
for(int i=0;i<=m;i++) cin>>g[i],g[i]%=P;
for(int i=0;i<=n;i++) fr[i]=f[n-i];
for(int i=0;i<=m;i++) gr[i]=g[m-i];
INV(inv_gr,gr,n-m+1);
MUL(fr,inv_gr,qr,n+1,n-m+1);
for(int i=0;i<n-m+1;i++) q[i]=qr[n-m-i];
for(int i=0;i<n-m+1;i++) cout<<q[i]%P<<' ';
cout<<endl;
MUL(q,g,tmpf,n-m+1,m+1);
for(int i=0;i<m;i++) cout<<(f[i]-tmpf[i]+P)%P<<" ";
return 0;
}
```
:::
# 多项式求导
对于幂函数 $x^k$,有
$$
\frac{d}{dx} x^k=kx^{k-1}
$$
此外,由于求导是线性运算,有
$$
\frac{d}{dx}(f(x)+g(x))=f'(x)+g'(x)\\
\frac{d}{dx}(C\times f(x))=C\times f'(x)
$$
其中 $C$ 为任意常数。
则对于任意多项式
$$
A(x)=a_0+a_1x+a_2x^2+\ldots+a_{n-1}x^{n-1}
$$
的求导,有
$$
A'(x)=a_1+2a_2x+3a_3x^2+\ldots+na_nx^{n-1}
$$
实现如下:
```
void DIFF(int f[],int g[],int n){
for(int i=1;i<n;i++) g[i-1]=i*f[i]%P;
g[n-1]=0;
}
```
# 多项式积分
求导的逆运算。
对于任意幂函数 $x^k$,有
$$
\int x^kdx=\frac{x^{k+1}}{k+1}+C
$$
其中 $k\neq-1$。
同求导,积分也是线性运算,即亦有
$$
\int(f(x)+g(x))dx=\int f(x)dx+\int g(x)dx\\
\int(C\times f(x))dx=C\times \int f(x)dx
$$
其中 $C$ 为常数。
则对于任意多项式
$$
A(x)=a_0+a_1x+a_2x^2+\ldots+a_{n-1}x^{n-1}
$$
的积分,有
$$
\int A(x)dx=a_0x+a_1\frac{x^2}{2}+a_2\frac{x^3}{3}+\ldots+a_n\frac{x^{n+1}}{n+1}+C
$$
实现如下:
```
void INTG(int f[],int g[],int n){
for(int i=1;i<n;i++) g[i]=f[i-1]*qpow(i,P-2)%P;
g[0]=0;
}
```
# 多项式 ln
#### [模板题](https://www.luogu.com.cn/problem/P4725)
考虑求导,则有
$$
\begin{aligned}
\ln (A(x))'&=\ln'(A(x))A'(x)\\
&=\frac{A'(x)}{A(x)}
\end{aligned}
$$
这个东西用多项式乘法逆即可,然后再积分回去即可。
:::success[Code]
```
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=6e6+10;
const int P=998244353,G=3;
int n;
int fa[maxn],fb[maxn];
int c[maxn],d[maxn],e[maxn],tmp[maxn];
int rev[maxn],bit,tot;
int qpow(int x,int k){
int base=1;
while(k){
if(k&1) base=base*x%P;
x=x*x%P;
k>>=1;
}
return base%P;
}
const int Gi=qpow(G,P-2);
void NTT(int a[],int type){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
int wi;
if(type==1) wi=qpow(G,(P-1)/(mid<<1));
else wi=qpow(Gi,(P-1)/(mid<<1));
for(int i=0;i<tot;i+=mid*2){
int wk=1;
for(int j=0;j<mid;j++,wk=wk*wi%P){
int x=a[i+j],y=wk*a[i+j+mid]%P;
a[i+j]=(x+y)%P;
a[i+j+mid]=(x-y+P)%P;
}
}
}
if(type==-1){
int inv_n=qpow(tot,P-2);
for(int i=0;i<tot;i++) a[i]=a[i]*inv_n%P;
}
}
void INV(int b[],int a[],int n){
if(n==1){
b[0]=qpow(a[0],P-2);
return;
}
INV(b,a,(n+1)>>1);
int len=1;
bit=0;
while(len<(n<<1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) c[i]=a[i];
for(int i=n;i<tot;i++) c[i]=0;
for(int i=(n+1)/2;i<tot;i++) b[i]=0;
NTT(b,1); NTT(c,1);
for(int i=0;i<tot;i++) b[i]=(2-b[i]*c[i]%P+P)%P*b[i]%P;
NTT(b,-1);
for(int i=n;i<tot;i++) b[i]=0;
}
void MUL(int a[],int b[],int c[],int n,int m){
int len=1;
bit=0;
while(len<(n+m-1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) d[i]=a[i];
for(int i=n;i<tot;i++) d[i]=0;
for(int i=0;i<m;i++) e[i]=b[i];
for(int i=m;i<tot;i++) e[i]=0;
NTT(d,1); NTT(e,1);
for(int i=0;i<tot;i++) c[i]=d[i]*e[i]%P;
NTT(c,-1);
}
void DIFF(int f[],int g[],int n){
for(int i=1;i<n;i++) g[i-1]=i*f[i]%P;
g[n-1]=0;
}
void INTG(int f[],int g[],int n){
for(int i=1;i<n;i++) g[i]=f[i-1]*qpow(i,P-2)%P;
g[0]=0;
}
void LN(int f[],int g[],int n){
int df[maxn],inv_f[maxn],mul_res[maxn];
for(int i=0;i<n;i++) tmp[i]=f[i];
DIFF(f,df,n);
INV(inv_f,tmp,n);
MUL(df,inv_f,mul_res,n,n);
INTG(mul_res,g,n);
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n;
for(int i=0;i<n;i++) cin>>fa[i],fa[i]%=P;
LN(fa,fb,n);
for(int i=0;i<n;i++) cout<<fb[i]%P<<' ';
return 0;
}
```
:::
# 多项式 exp
#### [模板题](https://www.luogu.com.cn/problem/P4726)
令 $\ln G(x)=F(x)$,则有
$$
F(x)-\ln G(x)=0
$$
考虑牛顿迭代,令 $H(t)=F(x)-\ln t$,则有
$$
H(G(x))=0\\
G_{n+1}(x)=G_n(x)-\frac{H(G_{n}(x))}{H'(G_{n}(x))}\\
G_{n+1}(x)=G_n(x)-\frac{F(x)-\ln G_n(x)}{-\frac{1}{G_n(x)}}\\
G_{n+1}(x)=G_n(x)(1+F(x)-\ln G_n(x))\\
$$
多项式乘法 + 多项式 ln + 多项式乘法逆即可。
:::success[Code]
```
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=6e6+10;
const int P=998244353,G=3;
int n;
int fa[maxn],fb[maxn];
int c[maxn],d[maxn],e[maxn],tmp[maxn];
int rev[maxn],bit,tot;
int ln_G[maxn],tmp_G[maxn];
int qpow(int x,int k){
int base=1;
while(k){
if(k&1) base=base*x%P;
x=x*x%P;
k>>=1;
}
return base%P;
}
const int Gi=qpow(G,P-2);
void NTT(int a[],int type){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
int wi;
if(type==1) wi=qpow(G,(P-1)/(mid<<1));
else wi=qpow(Gi,(P-1)/(mid<<1));
for(int i=0;i<tot;i+=mid*2){
int wk=1;
for(int j=0;j<mid;j++,wk=wk*wi%P){
int x=a[i+j],y=wk*a[i+j+mid]%P;
a[i+j]=(x+y)%P;
a[i+j+mid]=(x-y+P)%P;
}
}
}
if(type==-1){
int inv_n=qpow(tot,P-2);
for(int i=0;i<tot;i++) a[i]=a[i]*inv_n%P;
}
}
void INV(int b[],int a[],int n){
if(n==1){
b[0]=qpow(a[0],P-2);
return;
}
INV(b,a,(n+1)>>1);
int len=1;
bit=0;
while(len<(n<<1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) c[i]=a[i];
for(int i=n;i<tot;i++) c[i]=0;
for(int i=(n+1)/2;i<tot;i++) b[i]=0;
NTT(b,1); NTT(c,1);
for(int i=0;i<tot;i++) b[i]=(2-b[i]*c[i]%P+P)%P*b[i]%P;
NTT(b,-1);
for(int i=n;i<tot;i++) b[i]=0;
}
void MUL(int a[],int b[],int c[],int n,int m){
int len=1;
bit=0;
while(len<(n+m-1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) d[i]=a[i];
for(int i=n;i<tot;i++) d[i]=0;
for(int i=0;i<m;i++) e[i]=b[i];
for(int i=m;i<tot;i++) e[i]=0;
NTT(d,1); NTT(e,1);
for(int i=0;i<tot;i++) c[i]=d[i]*e[i]%P;
NTT(c,-1);
}
void DIFF(int f[],int g[],int n){
for(int i=1;i<n;i++) g[i-1]=i*f[i]%P;
g[n-1]=0;
}
void INTG(int f[],int g[],int n){
for(int i=1;i<n;i++) g[i]=f[i-1]*qpow(i,P-2)%P;
g[0]=0;
}
void LN(int f[],int g[],int n){
int df[maxn],inv_f[maxn],mul_res[maxn];
for(int i=0;i<n;i++) tmp[i]=f[i];
DIFF(f,df,n);
INV(inv_f,tmp,n);
MUL(df,inv_f,mul_res,n,n);
INTG(mul_res,g,n);
}
void EXP(int g[],int f[],int n){
if(n==1){
g[0]=1;
return;
}
EXP(g,f,(n+1)>>1);
LN(g,ln_G,n);
for(int i=0;i<n;i++) tmp_G[i]=(f[i]-ln_G[i]+P)%P;
tmp_G[0]=(tmp_G[0]+1)%P;
MUL(g,tmp_G,g,n,n);
for(int i=n;i<tot;i++) g[i]=0;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n;
for(int i=0;i<n;i++) cin>>fa[i],fa[i]%=P;
EXP(fb,fa,n);
for(int i=0;i<n;i++) cout<<fb[i]%P<<' ';
return 0;
}
```
:::
# 多项式开根
## 普通版
#### [模板题](https://www.luogu.com.cn/problem/P5205)
注意到 $A(x)^{\frac{1}{2}}=\exp(\frac{1}{2}\ln(A(x)))$。
然后套上面两个模板即可。
:::success[Code]
```cpp
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=1<<19;
const int P=998244353,G=3;
int n;
int fa[maxn],fb[maxn];
int c[maxn],d[maxn],e[maxn],tmp[maxn],tmp2[maxn];
int rev[maxn],bit,tot;
int ln_G[maxn],tmp_G[maxn];
int qpow(int x,int k){
int base=1;
while(k){
if(k&1) base=base*x%P;
x=x*x%P;
k>>=1;
}
return base%P;
}
const int Gi=qpow(G,P-2);
void NTT(int a[],int type){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
int wi;
if(type==1) wi=qpow(G,(P-1)/(mid<<1));
else wi=qpow(Gi,(P-1)/(mid<<1));
for(int i=0;i<tot;i+=mid*2){
int wk=1;
for(int j=0;j<mid;j++,wk=wk*wi%P){
int x=a[i+j],y=wk*a[i+j+mid]%P;
a[i+j]=(x+y)%P;
a[i+j+mid]=(x-y+P)%P;
}
}
}
if(type==-1){
int inv_n=qpow(tot,P-2);
for(int i=0;i<tot;i++) a[i]=a[i]*inv_n%P;
}
}
void INV(int b[],int a[],int n){
if(n==1){
b[0]=qpow(a[0],P-2);
return;
}
INV(b,a,(n+1)>>1);
int len=1;
bit=0;
while(len<(n<<1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) c[i]=a[i];
for(int i=n;i<tot;i++) c[i]=0;
for(int i=(n+1)/2;i<tot;i++) b[i]=0;
NTT(b,1); NTT(c,1);
for(int i=0;i<tot;i++) b[i]=(2-b[i]*c[i]%P+P)%P*b[i]%P;
NTT(b,-1);
for(int i=n;i<tot;i++) b[i]=0;
}
void MUL(int a[],int b[],int c[],int n,int m){
int len=1;
bit=0;
while(len<(n+m-1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) d[i]=a[i];
for(int i=n;i<tot;i++) d[i]=0;
for(int i=0;i<m;i++) e[i]=b[i];
for(int i=m;i<tot;i++) e[i]=0;
NTT(d,1); NTT(e,1);
for(int i=0;i<tot;i++) c[i]=d[i]*e[i]%P;
NTT(c,-1);
}
void DIFF(int f[],int g[],int n){
for(int i=1;i<n;i++) g[i-1]=i*f[i]%P;
g[n-1]=0;
}
void INTG(int f[],int g[],int n){
for(int i=1;i<n;i++) g[i]=f[i-1]*qpow(i,P-2)%P;
g[0]=0;
}
void LN(int f[],int g[],int n){
static int df[maxn],inv_f[maxn],mul_res[maxn];
for(int i=0;i<n;i++) tmp[i]=f[i];
DIFF(f,df,n);
INV(inv_f,tmp,n);
MUL(df,inv_f,mul_res,n,n);
INTG(mul_res,g,n);
}
void EXP(int g[],int f[],int n){
if(n==1){
g[0]=1;
return;
}
EXP(g,f,(n+1)>>1);
for(int i=0;i<n;i++) tmp_G[i]=g[i];
LN(tmp_G,ln_G,n);
for(int i=0;i<n;i++) tmp_G[i]=(f[i]-ln_G[i]+P)%P;
tmp_G[0]=(tmp_G[0]+1)%P;
MUL(g,tmp_G,g,n,n);
for(int i=n;i<tot;i++) g[i]=0;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n;
for(int i=0;i<n;i++) cin>>fa[i],fa[i]%=P;
LN(fa,fb,n);
for(int i=0;i<n;i++) fa[i]=fb[i]*qpow(2,P-2)%P;
EXP(fb,fa,n);
for(int i=0;i<n;i++) cout<<fb[i]%P<<' ';
return 0;
}
```
:::
## 加强版
#### [模板题](https://www.luogu.com.cn/problem/P5277)
:::info[二次剩余的讲解]
#### [模板题](https://www.luogu.com.cn/problem/P5491)
二次剩余指形如下面方程
$$
x^2\equiv n\pmod p
$$
其中 $p$ 为奇素数。上述方程无解的非 $0$ 数 $n$ 称作非二次剩余,否则为二次剩余。
如何判定一个数是否为二次剩余。
根据费马小定理,我们得到 $a^{p-1}\equiv1\pmod p$,由于 $p$ 为奇素数,所以有 $\left(a^{\frac{p-1}{2}}\right)^2\equiv1\pmod p$。
而在模 $p$ 意义下,一个数的平方为 $1$,这个数只能为 $1$ 或者 $-1$。
则分为两种情况讨论。
1. 若 $a$ 是二次剩余。
那么存在某个 $x$ 则有
$$
a^{\frac{p-1}{2}}\equiv \left(x^2\right)^{\frac{p-1}{2}}=x^{p-1}\equiv 1\pmod p
$$
2. 若 $a$ 不是二次剩余。
那么它不可能等于 $1$(因为上面说了二次剩余才会得 1),而结果又只能是 $1$ 或 $-1$,所以结果只能是 $-1$。
综上,当 $a^{\frac{p-1}{2}}\equiv 1\pmod p$ 时,$a$ 为二次剩余。
那如何求解上面的方程呢?
这里讲解 Cipolla 算法。
Cipolla 算法的原理是通过随机化 + 检验找到一个数 $a$,使得 $a^2-n$ 是非二次剩余,这样的概率是接近 $50\%$ 的,证明见 [oi-wiki](https://oi-wiki.org/math/number-theory/quad-residue/#euler-%E5%88%A4%E5%88%AB%E6%B3%95)。
定义一个复数 $i$,满足 $i^2\equiv a^2-n\pmod p$。
那么有 $(a+i)^{p+1}\equiv n\pmod p$,证明见[此文](https://www.luogu.com.cn/article/ify2j98h)。
则有一个解为 $(a+i)^{\frac{p+1}{2}}$,另一个解为其相反数。而 $(a+i)^{\frac{p+1}{2}}$ 的虚部必定为 $0$,证明亦见上面的文章。
:::success[Code]
```
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=1e5+10;
int t,mod,w;
struct node{
int x,y;
node operator*(const node& t) const{
return {(x*t.x%mod+y*t.y%mod*w%mod)%mod,(x*t.y%mod+y*t.x%mod)%mod};
}
};
int qpow1(int x,int k){
int base=1;
while(k){
if(k&1) base=base*x%mod;
x=x*x%mod;
k>>=1;
}
return base;
}
node qpow2(node x,int k){
node base={1,0};
while(k){
if(k&1) base=base*x;
x=x*x;
k>>=1;
}
return base;
}
bool check(int x){
if(x%mod==0) return 1;
return qpow1(x,(mod-1)/2)==1;
}
void Cipolla(int n,int &x1,int &x2){
int a=rand()%mod; w=(a*a%mod-n+mod)%mod;
while(check((a*a-n+mod)%mod)){
a=rand()%mod;
w=(a*a%mod-n+mod)%mod;
}
node res=qpow2({a,1},(mod+1)/2);
x1=res.x;
x2=(-x1+mod)%mod;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
srand(time(0));
cin>>t;
while(t--){
int n;
cin>>n>>mod;
if(n%mod==0){
cout<<"0\n";
continue;
}
if(!check(n)){
cout<<"Hola!\n";
continue;
}
int x1,x2;
Cipolla(n,x1,x2);
if(x1>x2) swap(x1,x2);
if(x1==x2) cout<<x1<<endl;
else cout<<x1<<" "<<x2<<endl;
}
return 0;
}
```
:::
由于多项式 ln 需保证常数项为 $1$,而此题不保证。
考虑牛顿迭代法。
$$
A(x)-B(x)^2\equiv 0\pmod{x^n}\\
C(t)=A(x)-t^2\\
C(B(x))\equiv 0\pmod{ x^n}\\
B_{n+1}(x)=B_n(x)-\frac{C(B_n(x))}{C'(B_n(x))}\\
B_{n+1}(x)=B_n(x)+\frac{A(x)-B_n(x)^2}{2B_n(x)}
$$
递推做即可。当 $n=1$ 时,做一遍二次剩余即可。
:::success[Code]
```
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=3e6+10;
const int P=998244353,G=3;
int n;
int a[maxn],b[maxn],c[maxn],d[maxn];
int rev[maxn],bit,tot;
int w;
struct node{
int x,y;
node operator*(const node& t) const{
return {(x*t.x%P+y*t.y%P*w%P)%P,(x*t.y%P+y*t.x%P)%P};
}
};
int qpow1(int x,int k){
int base=1;
while(k){
if(k&1) base=base*x%P;
x=x*x%P;
k>>=1;
}
return base;
}
node qpow2(node x,int k){
node base={1,0};
while(k){
if(k&1) base=base*x;
x=x*x;
k>>=1;
}
return base;
}
bool check(int x){
if(x%P==0) return 1;
return qpow1(x,(P-1)/2)==1;
}
int Cipolla(int n){
int a=rand()%P; w=(a*a%P-n+P)%P;
while(check((a*a-n+P)%P)){
a=rand()%P;
w=(a*a%P-n+P)%P;
}
node res=qpow2({a,1},(P+1)/2);
int x1=res.x,x2=(-x1+P)%P;
return min(x1,x2);
}
const int Gi=qpow1(G,P-2);
void NTT(int a[],int type){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
int wi;
if(type==1) wi=qpow1(G,(P-1)/(mid<<1));
else wi=qpow1(Gi,(P-1)/(mid<<1));
for(int i=0;i<tot;i+=mid*2){
int wk=1;
for(int j=0;j<mid;j++,wk=wk*wi%P){
int x=a[i+j],y=wk*a[i+j+mid]%P;
a[i+j]=(x+y)%P;
a[i+j+mid]=(x-y+P)%P;
}
}
}
if(type==-1){
int inv_n=qpow1(tot,P-2);
for(int i=0;i<tot;i++) a[i]=a[i]*inv_n%P;
}
}
void INV(int g[],int f[],int n){
if(n==1){
g[0]=qpow1(f[0],P-2);
return;
}
INV(g,f,(n+1)>>1);
int len=1;
bit=0;
while(len<(n<<1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) c[i]=f[i]; for(int i=n;i<tot;i++) c[i]=0;
NTT(g,1); NTT(c,1);
for(int i=0;i<tot;i++) g[i]=(2-g[i]*c[i]%P+P)%P*g[i]%P;
NTT(g,-1);
for(int i=n;i<tot;i++) g[i]=0;
}
int tmp_b1[maxn],tmp_b2[maxn];
void SQRT(int g[],int f[],int n){
if(n==1){
g[0]=Cipolla(f[0]);
return;
}
SQRT(g,f,(n+1)>>1);
for(int i=0;i<2*n;i++) d[i]=tmp_b1[i]=tmp_b2[i]=0;
for(int i=0;i<n;i++) tmp_b1[i]=g[i];
INV(d,tmp_b1,n);
int len=1;
bit=0;
while(len<(n<<1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) tmp_b2[i]=f[i];
NTT(d,1); NTT(tmp_b2,1);
for(int i=0;i<tot;i++) d[i]=d[i]*tmp_b2[i]%P;
NTT(d,-1);
int inv2=qpow1(2,P-2);
for(int i=0;i<tot;i++) g[i]=(g[i]+d[i])%P*inv2%P;
for(int i=n;i<tot;i++) g[i]=0;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n;
for(int i=0;i<n;i++) cin>>a[i],a[i]%=P;
SQRT(b,a,n);
for(int i=0;i<n;i++) cout<<b[i]<<' ';
return 0;
}
```
:::
# 多项式 k 次方根
## 普通版
:::info[题面]{open}
给定一个 $n-1$ 次多项式
$$
A(x)=a_0+a_1x+a_2x^2+\ldots+a_{n-1}x^{n-1}
$$
保证 $\red{a_0=1}$,且存在多项式 $B(x)$ 使得
$$
B(x)^k \equiv A(x) \pmod{x^n}
$$
在模 $998244353$ 的意义下成立。
你需要输出这个 $B(x)$,若存在多组解,输出字典序最小的解。
:::
同开根,有
$$
F(x)^{\frac{1}{k}}=\exp(\frac{1}{k}\ln F(x))
$$
直接多项式 exp+ln 做即可,代码就不给了。
## 加强版
:::info[题面]{open}
给定一个 $n-1$ 次多项式
$$
A(x)=a_0+a_1x+a_2x^2+\ldots+a_{n-1}x^{n-1}
$$
保证 $\red{a_0\neq0}$,且存在多项式 $B(x)$ 使得
$$
B(x)^k \equiv A(x) \pmod{x^n}
$$
在模 $998244353$ 的意义下成立。
你需要输出这个 $B(x)$,若存在多组解,输出字典序最小的解。
:::
依旧牛顿迭代法。
$$
A(x)-B(x)^k\equiv 0\pmod{x^n}\\
C(t)=A(x)-t^k\\
C(B(x))\equiv 0\pmod{ x^n}\\
B_{n+1}(x)=B_n(x)-\frac{C(B_n(x))}{C'(B_n(x))}\\
B_{n+1}(x)=B_n(x)+\frac{A(x)-B(x)^k}{kB(x)^{k-1}}
$$
递推做即可。当 $n=1$ 时,做一遍 $N$ 次剩余即可。
代码与多项式开根(加强版)类似,就不写了。
:::info[n 次剩余的讲解(模素数版)]
求解如下方程
$$
x^n\equiv a\pmod p
$$
其中 $p$ 为素数。
因为 $p$ 为素数,则模 $p$ 存在原根 $g$。设 $x\equiv g^y\pmod p,a\equiv g^t \pmod p$,则原方程变为:
$$
g^{ny}\equiv g^t\pmod p
$$
即
$$
ny\equiv t\pmod{p-1}
$$
其中,$t$ 可以通过 BSGS 求解 $g^t\equiv a\pmod p$ 得到。
然后令 $d=\gcd(n,m)$,则上述方程有解当且仅当 $d\mid t$。
若有解,则解集为:
$$
y=y_0+k\times\frac{p-1}{d},k\in[0,d-1]
$$
其中,$y_0$ 是一个特解。如何求解这个特解呢?
设 $n'=\frac{n}{d},m'=\frac{p-1}{d},t'=\frac{t}{d}$,则方程化为
$$
n'y\equiv t'\pmod{m'}
$$
因为 $\gcd(n',m')=1$,则 $n'$ 在模 $m'$ 意义下有逆元。于是 $y_0=t'\times (n')^{-1}$。
返回去求解 $x$,特解 $y_0$ 对应的 $x_0\equiv g^{y_0}\pmod{p}$。
因为 $y=y_0+k\times \frac{p-1}{d}$,所以
$$
x=g^{y_0+k\times\frac{p-1}{d}}=g^{y_0}\times \left(g^{\frac{p-1}{d}}\right)^k=x_0\times \omega^k
$$
其中 $\omega\equiv g^{\frac{p-1}{d}}\pmod p$,即 $\omega$ 是模 $p$ 意义下的 $d$ 次单位根。
综上,所有解为:
$$
x\equiv x_0\times \omega^k\pmod p,k\in[0,d-1]
$$
时间复杂度 $O(\sqrt p+\gcd(n,p-1))$。
:::success[Code]
```
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=1e5+10;
int p;
int q[maxn],cnt;
unordered_map<int,int> mp;
int exgcd(int a,int b,int &x,int &y){
if(!b){
x=1; y=0;
return a;
}
int gcd=exgcd(b,a%b,x,y);
int t=x;
x=y; y=t-a/b*y;
return gcd;
}
int bsgs(int x,int y){
mp.clear();
if(x==y) return 1;
int t=sqrt(p);
for(int i=1,j=x;i<=t;i++,j=j*x%p) mp[y*j%p]=i;
int s=1;
for(int i=1;i<=t;i++) s=s*x%p;
for(int i=1,j=s;i<=t;i++,j=j*s%p){
if(mp[j]) return i*t-mp[j];
}
return -1;
}
int qpow(int x,int k,int mod){
int base=1;
while(k){
if(k&1) base=base*x%mod;
x=x*x%mod;
k>>=1;
}
return base;
}
void fac(int x){
cnt=0;
if(x%2==0){
while(x%2==0) x/=2;
q[++cnt]=2;
}
for(int i=3;i*i<=x;i+=2){
if(x%i==0){
while(x%i==0) x/=i;
q[++cnt]=i;
}
}
if(x!=1) q[++cnt]=x;
}
int find_g(int x){
fac(x);
int i=2;
while(1){
bool flag=1;
for(int j=1;j<=cnt;j++){
if(qpow(i,x/q[j],p)==1){
flag=0; break;
}
}
if(flag) return i;
i++;
}
}
int inv_mod(int a,int mod){
int x,y;
int g=exgcd(a,mod,x,y);
if(g!=1) return -1;
return (x%mod+mod)%mod;
}
int find(int n,int a){
int g=find_g(p-1);
int t=bsgs(g,a),d=__gcd(n,p-1);
if(t%d!=0) return -1;
int mod=(p-1)/d;
int k0=t/d%mod*inv_mod(n/d,mod)%mod;
return qpow(g,k0,p);
}
void solve(int n,int a){
int x0=find(n,a);
if(x0==-1){
cout<<-1<<endl;
return;
}
int d=__gcd(n,p-1),g=find_g(p-1);
int omega=qpow(g,(p-1)/d,p);
vector<int> ans;
int base=x0;
for(int i=0;i<d;i++){
ans.push_back(base);
base=base*omega%p;
}
sort(ans.begin(),ans.end());
for(int v:ans) cout<<v<<' ';
cout<<endl;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
int t,n,k;
while(t--){
cin>>n>>p>>k;
solve(n,k);
}
return 0;
}
```
:::
# 多项式快速幂
## 普通版
#### [模板题](https://www.luogu.com.cn/problem/P5245)
和多项式开根普通版一样(可以理解为推广?),有如下式子:
$$
A(x)^k=\exp(k\ln(A(x)))
$$
值得注意的是,这道题 $1\le k\le10^{10^5}$,要先对 $k$ 取模再算。
:::success[Code]
```
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=1<<19;
const int P=998244353,G=3;
int n,k;
string K;
int fa[maxn],fb[maxn];
int c[maxn],d[maxn],e[maxn],tmp[maxn],tmp2[maxn];
int rev[maxn],bit,tot;
int ln_G[maxn],tmp_G[maxn];
int qpow(int x,int k){
int base=1;
while(k){
if(k&1) base=base*x%P;
x=x*x%P;
k>>=1;
}
return base%P;
}
const int Gi=qpow(G,P-2);
void NTT(int a[],int type){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
int wi;
if(type==1) wi=qpow(G,(P-1)/(mid<<1));
else wi=qpow(Gi,(P-1)/(mid<<1));
for(int i=0;i<tot;i+=mid*2){
int wk=1;
for(int j=0;j<mid;j++,wk=wk*wi%P){
int x=a[i+j],y=wk*a[i+j+mid]%P;
a[i+j]=(x+y)%P;
a[i+j+mid]=(x-y+P)%P;
}
}
}
if(type==-1){
int inv_n=qpow(tot,P-2);
for(int i=0;i<tot;i++) a[i]=a[i]*inv_n%P;
}
}
void INV(int b[],int a[],int n){
if(n==1){
b[0]=qpow(a[0],P-2);
return;
}
INV(b,a,(n+1)>>1);
int len=1;
bit=0;
while(len<(n<<1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) c[i]=a[i];
for(int i=n;i<tot;i++) c[i]=0;
for(int i=(n+1)/2;i<tot;i++) b[i]=0;
NTT(b,1); NTT(c,1);
for(int i=0;i<tot;i++) b[i]=(2-b[i]*c[i]%P+P)%P*b[i]%P;
NTT(b,-1);
for(int i=n;i<tot;i++) b[i]=0;
}
void MUL(int a[],int b[],int c[],int n,int m){
int len=1;
bit=0;
while(len<(n+m-1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) d[i]=a[i];
for(int i=n;i<tot;i++) d[i]=0;
for(int i=0;i<m;i++) e[i]=b[i];
for(int i=m;i<tot;i++) e[i]=0;
NTT(d,1); NTT(e,1);
for(int i=0;i<tot;i++) c[i]=d[i]*e[i]%P;
NTT(c,-1);
}
void DIFF(int f[],int g[],int n){
for(int i=1;i<n;i++) g[i-1]=i*f[i]%P;
g[n-1]=0;
}
void INTG(int f[],int g[],int n){
for(int i=1;i<n;i++) g[i]=f[i-1]*qpow(i,P-2)%P;
g[0]=0;
}
void LN(int f[],int g[],int n){
static int df[maxn],inv_f[maxn],mul_res[maxn];
for(int i=0;i<n;i++) tmp[i]=f[i];
DIFF(f,df,n);
INV(inv_f,tmp,n);
MUL(df,inv_f,mul_res,n,n);
INTG(mul_res,g,n);
}
void EXP(int g[],int f[],int n){
if(n==1){
g[0]=1;
return;
}
EXP(g,f,(n+1)>>1);
for(int i=0;i<n;i++) tmp_G[i]=g[i];
LN(tmp_G,ln_G,n);
for(int i=0;i<n;i++) tmp_G[i]=(f[i]-ln_G[i]+P)%P;
tmp_G[0]=(tmp_G[0]+1)%P;
MUL(g,tmp_G,g,n,n);
for(int i=n;i<tot;i++) g[i]=0;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n>>K;
for(int i=0;i<(int)K.size();i++) k=(k*10+(K[i]-'0'))%P;
for(int i=0;i<n;i++) cin>>fa[i],fa[i]%=P;
LN(fa,fb,n);
for(int i=0;i<n;i++) fa[i]=fb[i]*k%P;
EXP(fb,fa,n);
for(int i=0;i<n;i++) cout<<fb[i]%P<<' ';
return 0;
}
```
:::
## 加强版
#### [模板题](https://www.luogu.com.cn/problem/P5273)
不能直接做的原因和多项式开根加强版一致。
一个朴素的想法是算之前先除以一个 $A(0)$,输出时再乘上,则有
$$
A(x)^k=(\frac{A(x)}{A(0)})^k\times A(0)^k
$$
但万一 $A(0)=0$ 呢?
考虑找到第一个系数不为 $0$ 的项数,令这一项的次数为 $t$。然后把整个多项式降 $t$ 次,再做上述操作即可,最后再升 $t$ 次。则有
$$
A(x)^k = \left( \frac{A(x)}{a_t x^t} \right)^k \times x^{tk} a_t^k
$$
其中 $a_t$ 是 $x^t$ 的系数。
值得注意的是,$x^{tk}$ 中的 $k$ 不应取模,如果 $tk>n$,则输出全 $0$。而由费马小定理可得,$a_t^k$ 中的 $k$ 应对 $p-1$ 取模,其它的 $k$ 对 $p$ 取模。
:::success[Code]
```
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=1<<19;
const int P=998244353,G=3;
int n,k1,k2,k;
string K;
int fa[maxn],fb[maxn];
int c[maxn],d[maxn],e[maxn],tmp[maxn],tmp2[maxn];
int rev[maxn],bit,tot;
int ln_G[maxn],tmp_G[maxn];
int qpow(int x,int k){
int base=1;
while(k){
if(k&1) base=base*x%P;
x=x*x%P;
k>>=1;
}
return base%P;
}
const int Gi=qpow(G,P-2);
void NTT(int a[],int type){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
int wi;
if(type==1) wi=qpow(G,(P-1)/(mid<<1));
else wi=qpow(Gi,(P-1)/(mid<<1));
for(int i=0;i<tot;i+=mid*2){
int wk=1;
for(int j=0;j<mid;j++,wk=wk*wi%P){
int x=a[i+j],y=wk*a[i+j+mid]%P;
a[i+j]=(x+y)%P;
a[i+j+mid]=(x-y+P)%P;
}
}
}
if(type==-1){
int inv_n=qpow(tot,P-2);
for(int i=0;i<tot;i++) a[i]=a[i]*inv_n%P;
}
}
void INV(int b[],int a[],int n){
if(n==1){
b[0]=qpow(a[0],P-2);
return;
}
INV(b,a,(n+1)>>1);
int len=1;
bit=0;
while(len<(n<<1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) c[i]=a[i];
for(int i=n;i<tot;i++) c[i]=0;
for(int i=(n+1)/2;i<tot;i++) b[i]=0;
NTT(b,1); NTT(c,1);
for(int i=0;i<tot;i++) b[i]=(2-b[i]*c[i]%P+P)%P*b[i]%P;
NTT(b,-1);
for(int i=n;i<tot;i++) b[i]=0;
}
void MUL(int a[],int b[],int c[],int n,int m){
int len=1;
bit=0;
while(len<(n+m-1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) d[i]=a[i];
for(int i=n;i<tot;i++) d[i]=0;
for(int i=0;i<m;i++) e[i]=b[i];
for(int i=m;i<tot;i++) e[i]=0;
NTT(d,1); NTT(e,1);
for(int i=0;i<tot;i++) c[i]=d[i]*e[i]%P;
NTT(c,-1);
}
void DIFF(int f[],int g[],int n){
for(int i=1;i<n;i++) g[i-1]=i*f[i]%P;
g[n-1]=0;
}
void INTG(int f[],int g[],int n){
for(int i=1;i<n;i++) g[i]=f[i-1]*qpow(i,P-2)%P;
g[0]=0;
}
void LN(int f[],int g[],int n){
static int df[maxn],inv_f[maxn],mul_res[maxn];
for(int i=0;i<n;i++) tmp[i]=f[i];
DIFF(f,df,n);
INV(inv_f,tmp,n);
MUL(df,inv_f,mul_res,n,n);
INTG(mul_res,g,n);
}
void EXP(int g[],int f[],int n){
if(n==1){
g[0]=1;
return;
}
EXP(g,f,(n+1)>>1);
for(int i=0;i<n;i++) tmp_G[i]=g[i];
LN(tmp_G,ln_G,n);
for(int i=0;i<n;i++) tmp_G[i]=(f[i]-ln_G[i]+P)%P;
tmp_G[0]=(tmp_G[0]+1)%P;
MUL(g,tmp_G,g,n,n);
for(int i=n;i<tot;i++) g[i]=0;
}
void move(int f[],int g[],int t){
for(int i=0;i<n;i++){
if(0<=i-t&&i-t<=n-1) g[i]=f[i-t];
}
}
int t=-1;
int tmp_a[maxn];
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n>>K;
for(int i=0;i<n;i++){
cin>>fa[i],fa[i]%=P;
if(fa[i]&&t==-1) t=i;
}
if(t==-1){
for(int i=0;i<n;i++) cout<<"0 ";
return 0;
}
for(int i=0;i<(int)K.size();i++){
k1=(k1*10+(K[i]-'0'))%P; k2=(k2*10+(K[i]-'0'))%(P-1);
}
for(int i=0;i<(int)K.size();i++){
k=k*10+(K[i]-'0');
if(k>1e9){
k=-1; break;
}
}
if((k==-1&&t)||k*t>=n){
for(int i=0;i<n;i++) cout<<"0 ";
return 0;
}
move(fa,tmp_a,-t);
for(int i=0;i<n;i++) tmp_a[i]=tmp_a[i]*qpow(fa[t],P-2)%P;
LN(tmp_a,fb,n-t);
for(int i=0;i<n-t;i++) tmp_a[i]=fb[i]*k1%P;
EXP(fb,tmp_a,n-t);
for(int i=0;i<n;i++) fb[i]=fb[i]*qpow(fa[t],k2)%P;
fill(tmp_a,tmp_a+k*t,0);
move(fb,tmp_a,k*t);
for(int i=0;i<n;i++) cout<<tmp_a[i]<<' ';
return 0;
}
```
:::
时间复杂度 $O(n\log n)$。
# 分治 FFT
#### [模板题](https://www.luogu.com.cn/problem/P4721)
设 $F(x)=\sum_{i=0}^{\infty}f_ix^i$,$G(x)=\sum_{i=1}^{\infty}g_ix^i$(其中 $g_0=0$)。\
对于 $i\ge 1$,递推式为 $f_i=\sum_{j=1}^{i}f_{i-j}g_j=\sum_{j=0}^{i}f_{i-j}g_j$。于是
$$
\begin{aligned}
F(x)&=f_0+\sum_{i=1}^{\infty}\left(\sum_{j=0}^{i}f_{i-j}g_j\right)x^i \\
&=1+\sum_{i=0}^{\infty}\left(\sum_{j=0}^{i}f_{i-j}g_j\right)x^i - f_0g_0x^0 \\
&=1+\left(\sum_{i=0}^{\infty}\sum_{j=0}^{i}f_{i-j}g_jx^i\right)\\
&=1+F(x)G(x)
\end{aligned}
$$
因此 $F(x)-F(x)G(x)=1$,即 $F(x)(1-G(x))=1$,所以
$$
F(x)=\frac{1}{1-G(x)}
$$
多项式求逆即可。
:::success[Code]
```
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=3e6+10;
const int P=998244353,G=3;
int n;
int a[maxn],b[maxn],c[maxn],h[maxn];
int rev[maxn],bit,tot;
int qpow(int x,int k){
int base=1;
while(k){
if(k&1) base=base*x%P;
x=x*x%P;
k>>=1;
}
return base%P;
}
const int Gi=qpow(G,P-2);
void NTT(int a[],int type){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
int wi;
if(type==1) wi=qpow(G,(P-1)/(mid<<1));
else wi=qpow(Gi,(P-1)/(mid<<1));
for(int i=0;i<tot;i+=mid*2){
int wk=1;
for(int j=0;j<mid;j++,wk=wk*wi%P){
int x=a[i+j],y=wk*a[i+j+mid]%P;
a[i+j]=(x+y)%P;
a[i+j+mid]=(x-y+P)%P;
}
}
}
if(type==-1){
int inv_n=qpow(tot,P-2);
for(int i=0;i<tot;i++) a[i]=a[i]*inv_n%P;
}
}
void INV(int b[],int a[],int n){
if(n==1){
b[0]=qpow(a[0],P-2);
return;
}
INV(b,a,(n+1)>>1);
int len=1;
bit=0;
while(len<(n<<1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) c[i]=a[i]; for(int i=n;i<tot;i++) c[i]=0;
NTT(b,1); NTT(c,1);
for(int i=0;i<tot;i++) b[i]=(2-b[i]*c[i]%P+P)%P*b[i]%P;
NTT(b,-1);
for(int i=n;i<tot;i++) b[i]=0;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n;
for(int i=1;i<n;i++) cin>>a[i],a[i]%=P;
h[0]=(1-a[0]+P)%P;
for(int i=1;i<n;i++) h[i]=(P-a[i])%P;
INV(b,h,n);
for(int i=0;i<n;i++) cout<<b[i]%P<<' ';
return 0;
}
```
:::
# 多项式多点求值
#### [模板题](https://www.luogu.com.cn/problem/P5050)
我们首先需要知道:
+ 对于任意多项式 $F(x)$ 和常数 $a$,$F(x)$ 除以 $x-a$ 的余数为 $F(a)$,即 $F(x)\equiv F(a)\pmod{x-a}$。
:::info[证明]
根据多项式带余除法,存在唯一的多项式 $Q(x)$ 和常数 $R$(余数多项式次数低于除式次数,$x-a$ 是一次,故余数为常数),使得
$$
F(x)=(x-a)Q(x)+R
$$
代入 $x=a$,得
$$
F(a)=(a-a)Q(a)+R=R
$$
因此余数 $R=F(a)$,命题得证。
:::
而现在我们要求 $F(x_i)$,根据上面的定理,我们将 $F(x)$ 除以 $x-x_i$ 的余数就是 $F(x_i)$。
但这样还是要跑 $m$ 遍多项式取模,不 TLE 我吃。
我们考虑构造一个多项式
$$
M(x)=\prod_{i=1}^m(x-x_i)
$$
然后计算 $F(x)\bmod M(x)$,得到一个次数小于 $m$ 的余数多项式 $R(x)$。但此时的 $R(x)$ 并不是每个 $F(x_i)$ 的值,则我们考虑分治。
我们将 $M(x)$ 拆成两份。
$$
M_L(x)=\prod_{x_i\in S_L}(x-x_i)\\
M_R(x)=\prod_{x_i\in S_R}(x-x_i)
$$
其中,$S_L,S_R$ 分别表示左右两边的点集。
因为 $R(x)=F(x)\bmod M(x)$,所以
$$
R(x)\bmod M_L(x)=F(x)\bmod M_L(x)\\
R(x)\bmod M_R(x)=F(x)\bmod M_R(x)
$$
则我们可以递归处理每个子区间:将当前多项式分别对左、右子区间多项式取模,然后得到一个次数小于左、右子区间的多项式,然后继续向下分治。
当区间只有一个点 $x_i$ 时,多项式取模的结果就是 $F(x_i)$ 的值。
这个分治的实现类比线段树,因为有 $M(x)=M_L(x)\times M_R(x)$。
时间复杂度是 $O\left(n\log^2n\right)$ 的。
然后有个常数优化,当当前处理的子区间的长度小于某个数时(作者这里取 $600$),可以直接暴力算,用时可以变为原来的 $\frac{1}{2}$。
:::success[Code]
```
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=2e5+10;
const int P=998244353,G=3;
const int L=600;
int n,m;
int f[maxn],tmpf[maxn],x[maxn];
int fr[maxn],gr[maxn];
int qr[maxn],r[maxn];
int c[maxn],d[maxn],e[maxn],tmp[maxn];
int rev[maxn],bit,tot;
int inv_gr[maxn];
int qpow(int x,int k){
int base=1;
while(k){
if(k&1) base=base*x%P;
x=x*x%P;
k>>=1;
}
return base%P;
}
const int Gi=qpow(G,P-2);
void NTT(int a[],int type){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
int wi;
if(type==1) wi=qpow(G,(P-1)/(mid<<1));
else wi=qpow(Gi,(P-1)/(mid<<1));
for(int i=0;i<tot;i+=mid*2){
int wk=1;
for(int j=0;j<mid;j++,wk=wk*wi%P){
int x=a[i+j],y=wk*a[i+j+mid]%P;
a[i+j]=(x+y)%P;
a[i+j+mid]=(x-y+P)%P;
}
}
}
if(type==-1){
int inv_n=qpow(tot,P-2);
for(int i=0;i<tot;i++) a[i]=a[i]*inv_n%P;
}
}
void INV(int b[],int a[],int n){
if(n==1){
b[0]=qpow(a[0],P-2);
return;
}
INV(b,a,(n+1)>>1);
int len=1;
bit=0;
while(len<(n<<1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) c[i]=a[i];
for(int i=n;i<tot;i++) c[i]=0;
for(int i=(n+1)/2;i<tot;i++) b[i]=0;
NTT(b,1); NTT(c,1);
for(int i=0;i<tot;i++) b[i]=(2-b[i]*c[i]%P+P)%P*b[i]%P;
NTT(b,-1);
for(int i=n;i<tot;i++) b[i]=0;
}
void MUL(int a[],int b[],int c[],int n,int m){
int len=1;
bit=0;
while(len<(n+m-1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) d[i]=a[i];
for(int i=n;i<tot;i++) d[i]=0;
for(int i=0;i<m;i++) e[i]=b[i];
for(int i=m;i<tot;i++) e[i]=0;
NTT(d,1); NTT(e,1);
for(int i=0;i<tot;i++) c[i]=d[i]*e[i]%P;
NTT(c,-1);
}
void MOD(int a[],int b[],int c[],int d[],int n,int m){
if(n<m){
for(int i=0;i<=n;i++) d[i]=a[i];
for(int i=n+1;i<m;i++) d[i]=0;
return;
}
for(int i=0;i<=n;i++) fr[i]=a[n-i];
for(int i=0;i<=m;i++) gr[i]=b[m-i];
int L=n-m+1;
for(int i=m+1;i<L;i++) gr[i]=0;
INV(inv_gr,gr,L);
MUL(fr,inv_gr,qr,n+1,L);
for(int i=0;i<L;i++) c[i]=qr[L-1-i];
MUL(c,b,tmpf,L,m+1);
for(int i=0;i<m;i++) d[i]=(a[i]-tmpf[i]+P)%P;
}
const int MAXM=64000;
vector<int> poly[4*MAXM+10];
void build(int id,int l,int r){
if(l==r){
poly[id]={(P-x[l])%P,1};
return;
}
int mid=(l+r)>>1;
build(id*2,l,mid); build(id*2+1,mid+1,r);
int lenL=poly[id*2].size(),lenR=poly[id*2+1].size();
if(lenL+lenR-1<=L){//暴力算
vector<int> res(lenL+lenR-1,0);
for(int i=0;i<lenL;i++){
for(int j=0;j<lenR;j++){
res[i+j]=(res[i+j]+poly[id*2][i]*poly[id*2+1][j])%P;
}
}
poly[id]=move(res);
}
else{
for(int i=0;i<lenL;i++) d[i]=poly[id*2][i];
for(int i=0;i<lenR;i++) e[i]=poly[id*2+1][i];
MUL(d,e,c,lenL,lenR);
poly[id].assign(c,c+lenL+lenR-1);
}
}
void mod_v(vector<int>& a,const vector<int>& b){//暴力算
int n=(int)a.size()-1,m=(int)b.size()-1;
if(n<m){
a.resize(m,0);
return;
}
for(int i=n;i>=m;i--){
if(a[i]==0) continue;
int ra=a[i];
for(int j=0;j<=m;j++) a[i-m+j]=(a[i-m+j]-1LL*ra*b[j]%P+P)%P;
}
a.resize(m);
}
void query(int id,int l,int r,vector<int> p){
if((int)p.size()<=L&&(int)poly[id].size()<=L) mod_v(p,poly[id]);
else{
int n=p.size()-1,m=poly[id].size()-1;
int *a_arr=p.data(),*b_arr=poly[id].data();
MOD(a_arr,b_arr,c,d,n,m);
p.assign(d,d+m);
}
if(l==r){
int val=0;
for(int i=(int)p.size()-1;i>=0;i--) val=(val*x[l]%P+p[i])%P;
cout<<val<<endl;
return;
}
int mid=(l+r)>>1;
query(id*2,l,mid,p); query(id*2+1,mid+1,r,p);
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n>>m;
for(int i=0;i<=n;i++) cin>>f[i],f[i]%=P;
for(int i=1;i<=m;i++) cin>>x[i],x[i]%=P;
build(1,1,m);
vector<int> p(f,f+n+1);
query(1,1,m,p);
return 0;
}
```
:::
# 多项式快速插值
#### [模板题](https://www.luogu.com.cn/problem/P5158)
咕咕咕
# 多项式三角函数
实用性为 $0$,可跳过。
#### [模板题](https://www.luogu.com.cn/problem/P5264)
首先有欧拉公式
$$
\begin{aligned}
e^{ix}&=\cos x+i\sin x \quad &(1)\\
e^{-ix}&=\cos x-i\sin x \quad &(2)
\end{aligned}
$$
联立 $(1)$ 和 $(2)$ 两式,得
$$
\begin{aligned}
\cos x&=\frac{e^{ix}+e^{-ix}}{2}\\
\sin x&=\frac{e^{ix}-e^{-ix}}{2i}
\end{aligned}
$$
然后代入多项式 $A(x)$,得
$$
\begin{aligned}
\cos(A(x))&=\frac{e^{iA(x)}+e^{-iA(x)}}{2}\\
\sin(A(x))&=\frac{e^{iA(x)}-e^{-iA(x)}}{2i}
\end{aligned}
$$
多项式 exp 解决即可。
问题来到如何求 $i$ 在模 $998244353$ 意义下的值。
注意到 $i=\omega_4=g^{\frac{p-1}{4}}$,$g$ 是 $p$ 的原根,此题中为 $3$。然后就完了。
:::success[Code]
```
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=6e6+10;
const int P=998244353,G=3;
const int I=86583718;
int n;
int fa[maxn],fb[maxn];
int c[maxn],d[maxn],e[maxn],tmp[maxn];
int rev[maxn],bit,tot;
int ln_G[maxn],tmp_G[maxn];
int qpow(int x,int k){
int base=1;
while(k){
if(k&1) base=base*x%P;
x=x*x%P;
k>>=1;
}
return base%P;
}
const int Gi=qpow(G,P-2);
void NTT(int a[],int type){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
int wi;
if(type==1) wi=qpow(G,(P-1)/(mid<<1));
else wi=qpow(Gi,(P-1)/(mid<<1));
for(int i=0;i<tot;i+=mid*2){
int wk=1;
for(int j=0;j<mid;j++,wk=wk*wi%P){
int x=a[i+j],y=wk*a[i+j+mid]%P;
a[i+j]=(x+y)%P;
a[i+j+mid]=(x-y+P)%P;
}
}
}
if(type==-1){
int inv_n=qpow(tot,P-2);
for(int i=0;i<tot;i++) a[i]=a[i]*inv_n%P;
}
}
void INV(int b[],int a[],int n){
if(n==1){
b[0]=qpow(a[0],P-2);
return;
}
INV(b,a,(n+1)>>1);
int len=1;
bit=0;
while(len<(n<<1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) c[i]=a[i];
for(int i=n;i<tot;i++) c[i]=0;
for(int i=(n+1)/2;i<tot;i++) b[i]=0;
NTT(b,1); NTT(c,1);
for(int i=0;i<tot;i++) b[i]=(2-b[i]*c[i]%P+P)%P*b[i]%P;
NTT(b,-1);
for(int i=n;i<tot;i++) b[i]=0;
}
void MUL(int a[],int b[],int c[],int n,int m){
int len=1;
bit=0;
while(len<(n+m-1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) d[i]=a[i];
for(int i=n;i<tot;i++) d[i]=0;
for(int i=0;i<m;i++) e[i]=b[i];
for(int i=m;i<tot;i++) e[i]=0;
NTT(d,1); NTT(e,1);
for(int i=0;i<tot;i++) c[i]=d[i]*e[i]%P;
NTT(c,-1);
}
void DIFF(int f[],int g[],int n){
for(int i=1;i<n;i++) g[i-1]=i*f[i]%P;
g[n-1]=0;
}
void INTG(int f[],int g[],int n){
for(int i=1;i<n;i++) g[i]=f[i-1]*qpow(i,P-2)%P;
g[0]=0;
}
void LN(int f[],int g[],int n){
int df[maxn],inv_f[maxn],mul_res[maxn];
for(int i=0;i<n;i++) tmp[i]=f[i];
DIFF(f,df,n);
INV(inv_f,tmp,n);
MUL(df,inv_f,mul_res,n,n);
INTG(mul_res,g,n);
}
void EXP(int g[],int f[],int n){
if(n==1){
g[0]=1;
return;
}
EXP(g,f,(n+1)>>1);
LN(g,ln_G,n);
for(int i=0;i<n;i++) tmp_G[i]=(f[i]-ln_G[i]+P)%P;
tmp_G[0]=(tmp_G[0]+1)%P;
MUL(g,tmp_G,g,n,n);
for(int i=n;i<tot;i++) g[i]=0;
}
int f1[maxn],f2[maxn];
int f3[maxn],f4[maxn];
void COS(int f[],int g[],int n){
for(int i=0;i<n;i++) f1[i]=I*f[i]%P,f2[i]=(-f1[i]+P)%P;
EXP(f3,f1,n); EXP(f4,f2,n);
for(int i=0;i<n;i++) f3[i]=(f3[i]+f4[i])%P;
for(int i=0;i<n;i++) g[i]=f3[i]*qpow(2,P-2)%P;
}
void SIN(int f[],int g[],int n){
for(int i=0;i<n;i++) f1[i]=I*f[i]%P,f2[i]=(-f1[i]+P)%P;
EXP(f3,f1,n); EXP(f4,f2,n);
for(int i=0;i<n;i++) f3[i]=(f3[i]-f4[i]+P)%P;
for(int i=0;i<n;i++) g[i]=f3[i]*qpow(I,P-2)%P*qpow(2,P-2)%P;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
int type;
cin>>n>>type;
for(int i=0;i<n;i++) cin>>fa[i],fa[i]%=P;
if(type) COS(fa,fb,n);
else SIN(fa,fb,n);
for(int i=0;i<n;i++) cout<<fb[i]<<" ";
return 0;
}
```
:::
# 多项式反三角函数
实用性为同上。
#### [模板题](https://www.luogu.com.cn/problem/P5265)
依旧求导,则有
$$
\arcsin(A(x))'=\frac{A'(x)}{\sqrt{1-A(x)^2}}\pmod{x^n}\\
\arctan(A(x))'=\frac{A'(x)}{1+A(x)^2}\pmod{x^n}
$$
多项式开根 + 乘法逆解决,最后积分回去即可。
:::success[Code]
```
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=1<<19;
const int P=998244353,G=3;
int fa[maxn],fb[maxn];
int c[maxn],d[maxn],e[maxn],tmp[maxn],tmp2[maxn];
int rev[maxn],bit,tot;
int ln_G[maxn],tmp_G[maxn];
int qpow(int x,int k){
int base=1;
while(k){
if(k&1) base=base*x%P;
x=x*x%P;
k>>=1;
}
return base%P;
}
const int Gi=qpow(G,P-2);
void NTT(int a[],int type){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
int wi;
if(type==1) wi=qpow(G,(P-1)/(mid<<1));
else wi=qpow(Gi,(P-1)/(mid<<1));
for(int i=0;i<tot;i+=mid*2){
int wk=1;
for(int j=0;j<mid;j++,wk=wk*wi%P){
int x=a[i+j],y=wk*a[i+j+mid]%P;
a[i+j]=(x+y)%P;
a[i+j+mid]=(x-y+P)%P;
}
}
}
if(type==-1){
int inv_n=qpow(tot,P-2);
for(int i=0;i<tot;i++) a[i]=a[i]*inv_n%P;
}
}
void INV(int b[],int a[],int n){
if(n==1){
b[0]=qpow(a[0],P-2);
return;
}
INV(b,a,(n+1)>>1);
int len=1;
bit=0;
while(len<(n<<1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) c[i]=a[i];
for(int i=n;i<tot;i++) c[i]=0;
for(int i=(n+1)/2;i<tot;i++) b[i]=0;
NTT(b,1); NTT(c,1);
for(int i=0;i<tot;i++) b[i]=(2-b[i]*c[i]%P+P)%P*b[i]%P;
NTT(b,-1);
for(int i=n;i<tot;i++) b[i]=0;
}
void MUL(int a[],int b[],int c[],int n,int m){
int len=1;
bit=0;
while(len<(n+m-1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) d[i]=a[i];
for(int i=n;i<tot;i++) d[i]=0;
for(int i=0;i<m;i++) e[i]=b[i];
for(int i=m;i<tot;i++) e[i]=0;
NTT(d,1); NTT(e,1);
for(int i=0;i<tot;i++) c[i]=d[i]*e[i]%P;
NTT(c,-1);
}
void DIFF(int f[],int g[],int n){
for(int i=1;i<n;i++) g[i-1]=i*f[i]%P;
g[n-1]=0;
}
void INTG(int f[],int g[],int n){
for(int i=1;i<n;i++) g[i]=f[i-1]*qpow(i,P-2)%P;
g[0]=0;
}
void LN(int f[],int g[],int n){
static int df[maxn],inv_f[maxn],mul_res[maxn];
for(int i=0;i<n;i++) tmp[i]=f[i];
DIFF(f,df,n);
INV(inv_f,tmp,n);
MUL(df,inv_f,mul_res,n,n);
INTG(mul_res,g,n);
}
void EXP(int g[],int f[],int n){
if(n==1){
g[0]=1;
return;
}
EXP(g,f,(n+1)>>1);
for(int i=0;i<n;i++) tmp_G[i]=g[i];
LN(tmp_G,ln_G,n);
for(int i=0;i<n;i++) tmp_G[i]=(f[i]-ln_G[i]+P)%P;
tmp_G[0]=(tmp_G[0]+1)%P;
MUL(g,tmp_G,g,n,n);
for(int i=n;i<tot;i++) g[i]=0;
}
void SQRT(int f[],int g[],int n){
LN(f,g,n);
for(int i=0;i<n;i++) f[i]=g[i]*qpow(2,P-2)%P;
EXP(g,f,n);
}
int f1[maxn],f2[maxn],f3[maxn];
void ASIN(int f[],int g[],int n){
DIFF(f,f1,n);
MUL(f,f,f2,n,n);
for(int i=0;i<n;i++){
if(i==0) f2[i]=(1-f2[i]+P)%P;
else f2[i]=(-f2[i]+P)%P;
}
SQRT(f2,f3,n); INV(f2,f3,n);
MUL(f1,f2,f3,n,n);
INTG(f3,g,n);
}
void ATAN(int f[],int g[],int n){
DIFF(f,f1,n);
MUL(f,f,f2,n,n);
f2[0]=(1+f2[0])%P;
INV(f3,f2,2*n);
MUL(f1,f3,f2,n,2*n);
INTG(f2,g,n);
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
int n,type;
cin>>n>>type;
for(int i=0;i<n;i++) cin>>fa[i],fa[i]%=P;
if(!type){
ASIN(fa,fb,n);
for(int i=0;i<n;i++) cout<<fb[i]%P<<' ';
}
else{
ATAN(fa,fb,n);
for(int i=0;i<n;i++) cout<<fb[i]%P<<' ';
}
return 0;
}
```
:::
然后来讲一下如何上面两个式子怎么来的。
:::info[$\arcsin$]
设 $y=\arcsin x$,则 $x=\sin y$,其中 $y\in[-\frac{\pi}{2},\frac{\pi}{2}]$。
两边对 $x$ 求导,则有
$$
1=\cos y\times y'
$$
那么
$$
y'=\frac{1}{\cos y}=\frac{1}{\sqrt{1-\sin^2 y}}=\frac{1}{\sqrt{1-x^2}}
$$
因此
$$
\arcsin' x=\frac{1}{\sqrt{1-x^2}}
$$
而 $\arcsin(A(x))$ 是个复合函数,则
$$
\begin{aligned}
\arcsin(A(x))'&=\arcsin'(A(x))\times A(x)'\\
&=\frac{1}{\sqrt{1-A(x)^2}}\times A(x)'\\
&=\frac{A(x)'}{\sqrt{1-A(x)^2}}
\end{aligned}
$$
:::
:::info[$\arctan$]
设 $y=\arctan x$,则 $x=\tan y$,$y\in(-\frac{\pi}{2},\frac{\pi}{2})$。
两边对 $x$ 求导,则有
$$
1=\sec^2 y\times\frac{dy}{dx}
$$
由于 $\sec^2y=1+\tan^2y=1+x^2$,所以
$$
\frac{dy}{dx}=\frac{1}{1+x^2}
$$
因此
$$
\arctan' x=\frac{1}{1+x^2}
$$
而 $\arctan(A(x))$ 是个复合函数,则
$$
\begin{aligned}
\arctan(A(x))'&=\arctan'A(x)\times A(x)'\\
&=\frac{1}{1+A(x)^2}\times A(x)'\\
&=\frac{A(x)'}{1+A(x)^2}
\end{aligned}
$$
:::
# 应用
## 1. 求卷积
我们知道 FFT 和 NTT 可以用来求 $A(x)B(x)$ 的各项系数。
则考虑把 $A(x)B(x)$ 展开。
这里令 $n>m$,并将 $A(x),B(x)$ 用 $0$ 补至 $2n$ 次项系数。
则有:
$$
\begin{aligned}
A(x)B(x)&=\sum_{i=0}^na_ix^i\sum_{j=0}^nb_jx^j\\
&=\sum_{i=0}^n\sum_{j=0}^na_ib_jx^{i+j}\\
&=\sum_{i=0}^n\sum_{j=0}^n\sum_{k=0}^{2n}[i+j=k]a_ib_jx^k\\
&=\sum_{k=0}^{2n}\sum_{i=0}^n\sum_{j=0}^{n}[i+j=k]a_ib_jx^k\\
&=\sum_{k=0}^{2n}\sum_{i=0}^k a_i b_{k-i}x^k\\
&=\sum_{k=0}^{2n}\left(\sum_{i=0}^ka_ib_{k-i}\right)x^k\\
\end{aligned}
$$
则最后求得的系数为 $\sum_{i=0}^ka_ib_{k-i}$。
**恭喜你,发现了多项式乘法的第一个作用:求卷积!**
:::info[求解套路]
若底数之和为定值时,直接多项式乘法即可。
若底数之差为定值时,不妨以 $\sum_{i=0}^{n-k}a_{i}b_{k+i}$ 为例。
令 $b'_i=b_{n-i}
则有
\begin{aligned}
\sum_{i=0}^{n-k}a_{i}b_{k+i}
&=\sum_{i=0}^{n-k}a_ib'_{n-(k+i)}\\
&=\sum_{i=0}^{n-k}a_ib'_{n-k-i}
\end{aligned}
又令 t=n-k
则有
原式=\sum_{i=0}^ta_{i}b'_{t-i}
至此,式子转换为底数之和的形式。
:::
例题一:[ZJOI2014] 力
容易把题目所求转换为
\begin{aligned}
E_j&=\sum_{i=1}^{j-1} \frac{q_i}{(i - j)^2}~-~\sum_{i=j+1}^{n} \frac{q_i}{(i - j)^2}\\
&=\sum_{i=1}^{j} \frac{q_i}{(i - j)^2}~-~\sum_{i=j}^{n} \frac{q_i}{(i - j)^2}\\
\end{aligned}
不妨令 f_i=q_i,g_i=\frac{1}{i^2} 。
则有
\begin{aligned}
\sum_{i=1}^{j} \frac{q_i}{(i - j)^2}~-~\sum_{i=j}^{n} \frac{q_i}{(i - j)^2}\\
=\sum_{i=1}^j f_ig_{j-i}-\sum_{i=j}^{n}f_ig_{i-j}
\end{aligned}
不妨又令 f_0=0,g_0=0 。
则有
\begin{aligned}
\sum_{i=1}^j f_ig_{j-i}-\sum_{i=j}^{n}f_ig_{i-j}\\
=\sum_{i=0}^{j}f_ig_{j-i}-\sum_{i=j}^{n}f_ig_{i-j}
\end{aligned}
前面部分已经是卷积形式了,先丢一边推后面部分。
注意到 \sum_{i=j}^n f_ig_{i-j}=\sum_{i=0}^{n-j}f_{i+j}g_i 。
再根据我们的求解套路,令 t=n-j,f'(i)=f(n-i) 。
则有
\sum_{i=0}^{n-j}f_{i+j}g_i=\sum_{i=0}^{t}f'_{t-i}g_i
总合一下:
E_j=\sum_{i=0}^{j}f_ig_{j-i}-\sum_{i=0}^{t}f'_{t-i}g_i
至此,我们将所有式子转换为卷积形式,FFT 求解即可。
:::success[Code]
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=4e5+10;
const double pi=acos(-1);
int n;
struct node{
double x,y;
node operator+ (const node& t)const{
return {x+t.x,y+t.y};
}
node operator- (const node &t)const{
return {x-t.x,y-t.y};
}
node operator* (const node &t)const{
return {x*t.x-y*t.y,x*t.y+y*t.x};
}
}f1[maxn],f2[maxn],g[maxn];
node tmp1[maxn],tmp2[maxn];
int rev[maxn],bit,tot;
void FFT(node a[],int inv){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
auto w1=node({cos(pi/mid),inv*sin(pi/mid)});
for(int i=0;i<tot;i+=mid*2){
auto wk=node({1,0});
for(int j=0;j<mid;j++,wk=wk*w1){
auto x=a[i+j],y=wk*a[i+j+mid];
a[i+j]=x+y,a[i+j+mid]=x-y;
}
}
}
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n;
for(int i=1;i<=n;i++) cin>>f1[i].x,g[i].x=1.0/(i*i);
for(int i=1;i<=n;i++) f2[i]=f1[n-i+1];
while((1<<bit)<2*n+1) bit++;
tot=1<<bit;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
FFT(f1,1); FFT(g,1); FFT(f2,1);
for(int i=0;i<tot;i++) tmp1[i]=f1[i]*g[i],tmp2[i]=g[i]*f2[i];
FFT(tmp1,-1); FFT(tmp2,-1);
for(int i=1;i<=n;i++) printf("%.3lf\n",tmp1[i].x/tot-tmp2[n-i+1].x/tot);
return 0;
}
:::
例题二:[AHOI2017 / HNOI2017] 礼物
:::info[题意]{open}
给定两个长度为 n 的环 x,y ,环可以旋转,不可翻转,求 \sum_{i=1}^n(x_i-y_i+c)^2 的最小值,c 为任意一个非负整数。
:::
考虑把平方展开,则有
\begin{aligned}
\sum_{i=1}^n\left(x_i-y_i+c\right)^2
&=\sum_{i=1}^n\left(x_i^2-x_iy_i+cx_i-x_iy_i+y_i^2-cy_i+cx_i-cy_i+c^2\right)\\
&=\sum_{i=1}^n\left(x_i^2-2x_iy_i+y_i^2+(2x_i-2y_i)c+c^2\right)\\
&=\sum_{i=1}^nx_i^2+\sum_{i=1}^ny_i^2-2\sum_{i=1}^nx_iy_i+\sum_{i=1}^n\left(c^2+2\left(x_i-y_i\right)c\right)
\end{aligned}
注意到前面两项是定值,后面一项可用二次函数求最值解决。
问题来到如何解决 \sum_{i=1}^nx_iy_i 及旋转问题。
可以看做求解
\max_k \sum_{i=1}^nx_iy_{i+k}
先运用求解套路。
即有
\max_k\sum_{i=1}^nx'_{n-i}y_{i+k}
其中 x'_i=x_{n-i} 。
再考虑断环为链,即 y_{n+i}=y_i ,然后与 x' 做一遍 FFT,最后枚举最大值即可。
:::success[Code]
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=3e5+10;
const double pi=acos(-1);
int n,m;
int ans,sumx,sumy;
struct node{
double x,y;
node operator+ (const node&t)const{
return {x+t.x,y+t.y};
}
node operator- (const node&t)const{
return {x-t.x,y-t.y};
}
node operator* (const node&t)const{
return {x*t.x-y*t.y,x*t.y+y*t.x};
}
}a1[maxn],a2[maxn],b[maxn];
int rev[maxn],bit,tot;
void fft(node a[],int inv){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
auto w1=node({cos(pi/mid),inv*sin(pi/mid)});
for(int i=0;i<tot;i+=mid*2){
auto wk=node({1,0});
for(int j=0;j<mid;j++,wk=wk*w1){
auto x=a[i+j],y=a[i+j+mid]*wk;
a[i+j]=x+y; a[i+j+mid]=x-y;
}
}
}
if(inv==-1){
for(int i=0;i<tot;i++) a[i].x/=tot;
}
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n>>m;
for(int i=0;i<n;i++){
cin>>a1[i].x;
ans+=a1[i].x*a1[i].x;
sumx+=a1[i].x;
}
for(int i=0;i<n;i++){
cin>>b[i].x;
b[n+i].x=b[i].x;
ans+=b[i].x*b[i].x;
sumy+=b[i].x;
}
for(int i=0;i<n;i++) a2[i].x=a1[n-i-1].x;
int len=3*n;
while((1<<bit) < len+1) bit++;
tot=1<<bit;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
fft(a2,1); fft(b,1);
for(int i=0;i<tot;i++) b[i]=b[i]*a2[i];
fft(b,-1);
double maxx=LLONG_MIN;
for(int k=0;k<n;k++) maxx=max(maxx,b[n-1+k].x);
int c=llround((sumy-sumx)*1.0/n);
cout<<ans-2*llround(maxx)+n*c*c+2*c*(sumx-sumy);
return 0;
}
:::
2. 优化 DP
此类问题一般通过 DP 转移方程构建生成函数,通过多项式解决。
没有生成函数基础的先看这里。
例题:付公主的背包
对于一个体积为 v 的物品,容易构建出如下生成函数:
A(x)=\sum_{i=0}^\infty x^i[i\bmod v=0]=\sum_{k=0}^{\infty}x^{kv}
可以看做消耗 kv 的体积去装 k 个这样的物品。
那么根据乘法原理可得,答案即为这 n 个生成函数的积,直接做的话时间复杂度是 O(nm\log m) 的。
我们考虑把生成函数写成封闭形式,则有
\sum_{k=0}^{\infty}x^{kv}=\frac{1}{1-x^v}
那答案则为
F(x)=\prod_{i=1}^n\frac{1}{1-x^{v_i}}
然后有个很经典的技巧:将多项式取对数,可以将乘法转为加法。令 G(x)=\ln F(x) ,则有
G(x)=-\sum_{i=1}^n \ln(1-x^{v_i})
然后利用 \ln(1-x^{v_i})=-\sum_{k=1}^{\infty}\frac{x^{kv_i}}{k} ,得到
G(x)=\sum_{i=1}^n\sum_{k=1}^{\infty}\frac{x^{kv_i}}{k}
那么对于一个体积 v ,统计它的出现次数 c ,然后枚举 k ,给第 kv 项上加上 \frac{c}{k} 。这样,我们就得到了 G(x) 每一项的系数,再对 G(x) 求一次 \exp 即可。
:::success[Code]
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=4e5+10;
const int P=998244353,G=3;
int n,m;
int fa[maxn],fb[maxn];
int c[maxn],d[maxn],e[maxn],tmp[maxn];
int rev[maxn],bit,tot;
int ln_G[maxn],tmp_G[maxn];
int qpow(int x,int k){
int base=1;
while(k){
if(k&1) base=base*x%P;
x=x*x%P;
k>>=1;
}
return base%P;
}
const int Gi=qpow(G,P-2);
void NTT(int a[],int type){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
int wi;
if(type==1) wi=qpow(G,(P-1)/(mid<<1));
else wi=qpow(Gi,(P-1)/(mid<<1));
for(int i=0;i<tot;i+=mid*2){
int wk=1;
for(int j=0;j<mid;j++,wk=wk*wi%P){
int x=a[i+j],y=wk*a[i+j+mid]%P;
a[i+j]=(x+y)%P;
a[i+j+mid]=(x-y+P)%P;
}
}
}
if(type==-1){
int inv_n=qpow(tot,P-2);
for(int i=0;i<tot;i++) a[i]=a[i]*inv_n%P;
}
}
void INV(int b[],int a[],int n){
if(n==1){
b[0]=qpow(a[0],P-2);
return;
}
INV(b,a,(n+1)>>1);
int len=1;
bit=0;
while(len<(n<<1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) c[i]=a[i];
for(int i=n;i<tot;i++) c[i]=0;
for(int i=(n+1)/2;i<tot;i++) b[i]=0;
NTT(b,1); NTT(c,1);
for(int i=0;i<tot;i++) b[i]=(2-b[i]*c[i]%P+P)%P*b[i]%P;
NTT(b,-1);
for(int i=n;i<tot;i++) b[i]=0;
}
void MUL(int a[],int b[],int c[],int n,int m){
int len=1;
bit=0;
while(len<(n+m-1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) d[i]=a[i];
for(int i=n;i<tot;i++) d[i]=0;
for(int i=0;i<m;i++) e[i]=b[i];
for(int i=m;i<tot;i++) e[i]=0;
NTT(d,1); NTT(e,1);
for(int i=0;i<tot;i++) c[i]=d[i]*e[i]%P;
NTT(c,-1);
}
void DIFF(int f[],int g[],int n){
for(int i=1;i<n;i++) g[i-1]=i*f[i]%P;
g[n-1]=0;
}
void INTG(int f[],int g[],int n){
for(int i=1;i<n;i++) g[i]=f[i-1]*qpow(i,P-2)%P;
g[0]=0;
}
int df[maxn],inv_f[maxn],mul_res[maxn];
void LN(int f[],int g[],int n){
for(int i=0;i<n;i++) tmp[i]=f[i];
DIFF(f,df,n);
INV(inv_f,tmp,n);
MUL(df,inv_f,mul_res,n,n);
INTG(mul_res,g,n);
}
void EXP(int g[],int f[],int n){
if(n==1){
g[0]=1;
return;
}
EXP(g,f,(n+1)>>1);
LN(g,ln_G,n);
for(int i=0;i<n;i++) tmp_G[i]=(f[i]-ln_G[i]+P)%P;
tmp_G[0]=(tmp_G[0]+1)%P;
MUL(g,tmp_G,g,n,n);
for(int i=n;i<tot;i++) g[i]=0;
}
int inv[maxn];
void init(int n){
inv[0]=inv[1]=1;
for(int i=2;i<n;i++) inv[i]=inv[P%i]*(P-P/i)%P;
}
int t[maxn];
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n>>m;
init(m+5);
for(int i=1;i<=n;i++){
int x;
cin>>x; t[x]++;
}
for(int i=1;i<=m;i++){
if(t[i]){
for(int j=i;j<=m;j+=i) fa[j]=(fa[j]+t[i]*inv[j/i])%P;
}
}
EXP(fb,fa,m+1);
for(int i=1;i<=m;i++) cout<<fb[i]%P<<endl;
return 0;
}
:::
总结
本文讲得都是一些基础的多项式操作以及一些经典例题,多项式快速差值后面有时间会考虑补上。作者的代码常数较好,请放心食用。
参考文献
attack-【模板】多项式乘法(NTT)
Prean - 模板 P4245 题解
wsy_I 的 CRT 学习笔记(不方便公开)
KAMIYA_KINA - P4238 【模板】多项式乘法逆 题解
Kewth - 题解 P5491 【模板】二次剩余
《高等数学》第七版上册 高等教育出版社