基于倍增维护多项式点值的计算
NaCly_Fish · · 个人记录
有一些问题,例如阶乘、阶乘之和,可以很容易地以 $\Theta(n)$ 的时间复杂度计算,但当 $n$ 很大时这样做就太慢了。
于是就有了 min_25 提出的,一种基于倍增维护多项式点值的算法,可以将时间复杂度优化到 $\Theta(\sqrt n\log n)$。
算是这类问题中比较基础的一个。
下面介绍计算 完全平方数 的阶乘的方法(如果不是完全平方数,暴力乘一段即可,这部分复杂度不超过 $\Theta(\sqrt n)$)。
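设 $s=\sqrt n$,$f_d(x)=\prod_{i=1}^{d}(x+i)$,则 $n!=\prod_{k=0}^{s-1}f_s(sk)$,且 $f_{2d}(x)=f_d(x)\,f_d(x+d)$。于是从 $d$ 倍增到 $2d$ 需要解决两个子问题:
- 由 $f_d(sk)\ (k\in[0,d])$ 计算 $f_d(sk)\ (k\in[d+1,2d])$
- 由 $f_d(sk)\ (k\in[0,2d])$ 计算 $f_d(sk+d)\ (k\in[0,2d])$
而且这两个子问题有着类似的形式(多项式平移),用 这题的方法 就可以在 $\Theta(d\log d)$ 的时间内完成。
相比之下,由 $f_d$ 得到 $f_{d+1}$ 就简单多了:$f_{d+1}(x)=f_d(x)\,(x+d+1)$,直接 $\Theta(d)$ 计算即可。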
参考代码(可以计算 $n!\bmod 998244353$):
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
#define N 262147
#define ll long long
#define reg register
#define p 998244353
using namespace std;
// Modular exponentiation by repeated squaring: returns a^t modulo p.
// Intended for t >= 0; each set bit of t multiplies the matching a^(2^j) into the result.
inline int power(int a,int t){
    int acc = 1;
    for(;t;t >>= 1){
        if(t&1) acc = (ll)acc*a%p;
        a = (ll)a*a%p;
    }
    return acc;
}
int fac[N],ifac[N],rt[N],rev[N],facm[N];
int siz;
// Precomputes the tables shared by NTT() and lagrange() for sizes up to n:
//   rev[]          bit-reversal permutation for the largest transform length lim,
//   rt[]           twiddle factors, laid out so that rt[mid|k] is the k-th power of
//                  the primitive (2*mid)-th root of unity,
//   fac[], ifac[]  factorials and inverse factorials of 0..n modulo p.
// lim is the smallest power of two strictly greater than n; every table is indexed
// below lim (or up to n), so both must stay within the N-sized arrays.
void init(int n){
    int lim = 1;
    // siz is a global bit counter; reset it so that a repeated init() call does not
    // keep adding bits on top of the previous value and corrupt rev[]/rt[].
    siz = 0;
    while(lim<=n) lim <<= 1,++siz;
    for(int i=1;i!=lim;++i) rev[i] = (rev[i>>1]>>1)|((i&1)<<(siz-1));
    // 3 is used as the primitive root of p, so w has multiplicative order lim.
    const int w = power(3,(p-1)>>siz);
    fac[0] = fac[1] = ifac[0] = ifac[1] = rt[lim>>1] = 1;
    // Upper half: rt[lim/2 + k] = w^k.
    for(int i=lim>>1|1;i!=lim;++i) rt[i] = (ll)rt[i-1]*w%p;
    // Lower levels reuse every second entry of the level above.
    for(int i=(lim>>1)-1;i;--i) rt[i] = rt[i<<1];
    for(int i=2;i<=n;++i) fac[i] = (ll)fac[i-1]*i%p;
    // One modular inverse for n!, then walk downwards: 1/i! = (i+1) / (i+1)!.
    ifac[n] = power(fac[n],p-2);
    for(int i=n-1;i;--i) ifac[i] = (ll)ifac[i+1]*(i+1)%p;
}
inline void NTT(int *f,int type,int lim){
if(type==-1) reverse(f+1,f+lim);
static unsigned long long a[N];
reg int x,shift = siz-__builtin_ctz(lim);
for(reg int i=0;i!=lim;++i) a[rev[i]>>shift] = f[i];
for(reg int mid=1;mid!=lim;mid<<=1)
for(reg int j=0;j!=lim;j+=(mid<<1))
for(reg int k=0;k!=mid;++k){
x = a[j|k|mid]*rt[mid|k]%p;
a[j|k|mid] = a[j|k]+p-x;
a[j|k] += x;
}
for(reg int i=0;i!=lim;++i) f[i] = a[i]%p;
if(type==1) return;
x = p-(p-1)/lim;
for(reg int i=0;i!=lim;++i) f[i] = (ll)f[i]*x%p;
}
// Shifted evaluation by Lagrange interpolation (all arithmetic modulo p).
//   Input : F[0..n] = f(0), f(1), ..., f(n) for a polynomial f of degree <= n.
//   Output: R[0..n] = f(m), f(m+1), ..., f(m+n).
// Uses f(m+i) = [prod_{t=0..n}(m+i-t)] * sum_j F[j]*(-1)^(n-j)/(j!(n-j)!) * 1/(m+i-j);
// the sum over j is a convolution and is evaluated with NTT.
// Requirements visible from the code: n >= 1 (n*3 is passed to __builtin_clz), and every
// value m-n+t for t = 0..2n must be invertible modulo p because all of them are inverted.
// Side effect: overwrites the global facm[].
void lagrange(const int *F,int n,int m,int *R){
static int f[N],G[N],pre[N],suf[N],inv[N];
memcpy(f,F,(n+1)<<2);
// k = 2n+1 inverses are needed; lim is the smallest power of two greater than 3n.
int tmp,k = n<<1|1,mul,lim = 1<<(32-__builtin_clz(n*3));
// If the requested range overlaps the known points, the first tmp outputs are simply
// F[m..n]; only the part starting at n+1 is interpolated.
if(m<=n) tmp = n-m+1,m = n+1;
else tmp = 0;
// facm[i] = prod_{t=0..n} (m+i-t); facm[0] is computed directly.
facm[0] = 1;
for(reg int i=0;i<=n;++i) facm[0] = (ll)facm[0]*(m-n+i)%p;
// Batch inversion of m-n, m-n+1, ..., m+n via prefix and suffix products:
// inv[i] = 1/(m-n+i-1) for i = 1..k, with a single modular exponentiation.
pre[0] = suf[k+1] = 1;
for(reg int i=1;i<=k;++i) pre[i] = (ll)pre[i-1]*(m-n+i-1)%p;
for(reg int i=k;i;--i) suf[i] = (ll)suf[i+1]*(m-n+i-1)%p;
mul = power(pre[k],p-2);
for(reg int i=1;i<=k;++i) inv[i] = (ll)mul*pre[i-1]%p*suf[i+1]%p;
// Slide the window: drop the factor (m+i-1-n) and append the factor (m+i).
for(reg int i=1;i<=n;++i) facm[i] = (ll)facm[i-1]*(m+i)%p*inv[i]%p;
// f[i] = F[i] * (-1)^(n-i) / (i! (n-i)!)
for(reg int i=0;i<=n;++i){
f[i] = (ll)f[i]*ifac[i]%p*ifac[n-i]%p;
if((n-i)&1) f[i] = p-f[i];
}
// G[t] = 1/(m-n+t), so (f*G)[i+n] = sum_j f[j]/(m+i-j).
for(reg int i=0;i!=k;++i) G[i] = inv[i+1];
// Zero the padding up to (and including) index lim before transforming.
memset(f+n+1,0,(lim-n)<<2);
memset(G+k,0,(lim-k+1)<<2);
NTT(f,1,lim),NTT(G,1,lim);
for(reg int i=0;i!=lim;++i) f[i] = (ll)f[i]*G[i]%p;
NTT(f,-1,lim);
// Already-known values first, then the interpolated ones.
memcpy(R,F+n-tmp+1,tmp<<2);
for(reg int i=tmp;i<=n;++i) R[i] = (ll)f[i+n-tmp]*facm[i-tmp]%p;
}
// Returns (n*n)! modulo p.
// With s = n, the array f holds the values f_d(0), ..., f_d(d) of
//     f_d(x) = prod_{i=1..d} (s*x + i),
// and d is grown from 1 to s by following the binary digits of s
// (doubling, plus one extra step whenever the target is odd).
// The answer is prod_{i=0..s-1} f_s(i).
int solve(int n){
    static int f[N],fd[N],st[30];
    // 0! = 1. Without this guard the halving chain below is empty, top becomes -1,
    // and the doubling loop runs with invalid indices.
    if(n<=0) return 1;
    memset(f,0,sizeof(f));
    int top = 0,s = n;
    // st[1] = s, st[2] = s/2, ..., st[top] = 1: the degrees to visit, largest first.
    while(n){
        st[++top] = n;
        n >>= 1;
    }
    n = st[top--];
    f[0] = 1,f[1] = s+1; // f_1(0) and f_1(1)
    while(top--){
        // Extend the known values from 0..n to n+1..2n+1.
        lagrange(f,n,n+1,f+n+1);
        f[n<<1|1] = 0;
        // Shifting the index by n/s (mod p) moves the argument s*i to s*i + n,
        // so fd[i] = prod_{j=1..n} (s*i + n + j).
        int tmp = (ll)n*power(s,p-2)%p;
        lagrange(f,n<<1,tmp,fd);
        for(int i=0;i<=(n<<1);++i) f[i] = (ll)f[i]*fd[i]%p;
        n <<= 1;
        if(!(st[top+1]&1)) continue;
        // Target degree is odd: go from n to n+1 by multiplying in the factor (s*i + n + 1) ...
        for(int i=0;i<=n;++i) f[i] = (ll)f[i]*(s*i+n+1)%p;
        // ... and compute the one additional point f_{n+1}(n+1) directly.
        f[n+1] = 1;
        for(int i=1;i<=n+1;++i) f[n+1] = (ll)f[n+1]*(s*(n+1)+i)%p;
        ++n;
    }
    int res = f[0];
    for(int i=1;i!=s;++i) res = (ll)res*f[i]%p;
    return res;
}
inline int factorial(int n){
int k = sqrt(n),res;
res = solve(k);
for(reg int i=k*k+1;i<=n;++i) res = (ll)res*i%p;
return res;
}
// Reads T test cases, each a single integer n, and prints n! modulo p for each.
int main(){
    // init(150000) selects the transform length 2^18, the largest the N-sized buffers hold.
    init(150000);
    int T,n;
    // Stop cleanly on missing or malformed input instead of using an uninitialised value.
    if(scanf("%d",&T)!=1) return 0;
    while(T--){
        if(scanf("%d",&n)!=1) break;
        printf("%d\n",factorial(n));
    }
    return 0;
}
ps:还可以使用威尔逊定理优化常数
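设 $s=\lfloor\sqrt m\rfloor$,并定义
$$g_d(x)=\prod_{i=1}^{d}(x+i),\quad h_d(x)=\prod_{i=1}^{d}(n-x-i+1),\quad f_d(x)=\sum_{j=0}^{d-1}\Big(\prod_{i=1}^{j}(n-x-i+1)\Big)\Big(\prod_{i=j+1}^{d}(x+i)\Big)$$
容易发现它和答案的关系:$\sum_{i=0}^{s^2}\binom ni=\binom n{s^2}+\sum_{k=0}^{s-1}\binom n{sk}\dfrac{f_s(sk)}{g_s(sk)}$,其中 $\binom n{s(k+1)}=\binom n{sk}\dfrac{h_s(sk)}{g_s(sk)}$。
沿用上面的方法,快速处理出这些阶乘的值,这里就不赘述了。
考虑如何由 $f_d$ 得到 $f_{2d}$:$f_{2d}(x)=f_d(x)\,g_d(x+d)+h_d(x)\,f_d(x+d)$
(照着上面推一推就能出这个式子)
这样一来,三个多项式的点值都能倍增计算。
而且巧合的是,在此过程中计算的 $g_s(sk)$ 和 $h_s(sk)$ 正好可以用来递推 $\binom n{sk}$。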
最后会多出来一小段,用组合数递推计算即可。
时间复杂度 $\Theta(\sqrt m\log m)$。
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
#define N 262147
#define ll long long
#define reg register
#define p 998244353
#define add(x,y) (x+y>=p?x+y-p:x+y)
#define dec(x,y) (x<y?x-y+p:x-y)
using namespace std;
// Returns a^t modulo p using binary exponentiation (intended for t >= 0).
inline int power(int a,int t){
    int result = 1;
    while(t){
        if(t&1){
            result = (ll)result*a%p;
        }
        t >>= 1;
        a = (ll)a*a%p;
    }
    return result;
}
int fac[N],ifac[N],rt[N],rev[N],facm[N],inv[N];
int siz;
// Precomputes the tables shared by the rest of the program for sizes up to n:
//   rev[]          bit-reversal permutation for the largest transform length lim,
//   rt[]           twiddle factors, rt[mid|k] = k-th power of the primitive (2*mid)-th root,
//   fac[], ifac[]  factorials and inverse factorials of 0..n modulo p,
//   inv[]          modular inverses of 1..n.
// lim is the smallest power of two strictly greater than n and must fit in the N-sized arrays.
void init(int n){
    int lim = 1;
    // siz is a global bit counter; reset it so that calling init() more than once
    // does not accumulate bits and corrupt rev[]/rt[].
    siz = 0;
    while(lim<=n) lim <<= 1,++siz;
    for(int i=1;i!=lim;++i) rev[i] = (rev[i>>1]>>1)|((i&1)<<(siz-1));
    // 3 is used as the primitive root of p, so w has multiplicative order lim.
    const int w = power(3,(p-1)>>siz);
    inv[1] = fac[0] = fac[1] = ifac[0] = ifac[1] = rt[lim>>1] = 1;
    for(int i=lim>>1|1;i!=lim;++i) rt[i] = (ll)rt[i-1]*w%p;
    for(int i=(lim>>1)-1;i;--i) rt[i] = rt[i<<1];
    for(int i=2;i<=n;++i) fac[i] = (ll)fac[i-1]*i%p;
    ifac[n] = power(fac[n],p-2);
    for(int i=n-1;i;--i) ifac[i] = (ll)ifac[i+1]*(i+1)%p;
    // Linear-time inverses: from p = (p/i)*i + p%i it follows that 1/i = -(p/i) * 1/(p%i).
    for(int i=2;i<=n;++i) inv[i] = (ll)(p-p/i)*inv[p%i]%p;
}
inline void NTT(int *f,int type,int lim){
if(type==-1) reverse(f+1,f+lim);
static unsigned long long a[N];
reg int x,shift = siz-__builtin_ctz(lim);
for(reg int i=0;i!=lim;++i) a[rev[i]>>shift] = f[i];
for(reg int mid=1;mid!=lim;mid<<=1)
for(reg int j=0;j!=lim;j+=(mid<<1))
for(reg int k=0;k!=mid;++k){
x = a[j|k|mid]*rt[mid|k]%p;
a[j|k|mid] = a[j|k]+p-x;
a[j|k] += x;
}
for(reg int i=0;i!=lim;++i) f[i] = a[i]%p;
if(type==1) return;
x = p-(p-1)/lim;
for(reg int i=0;i!=lim;++i) f[i] = (ll)f[i]*x%p;
}
// Shifted evaluation by Lagrange interpolation (all arithmetic modulo p).
//   Input : F[0..n] = f(0), f(1), ..., f(n) for a polynomial f of degree <= n.
//   Output: R[0..n] = f(m), f(m+1), ..., f(m+n).
// Uses f(m+i) = [prod_{t=0..n}(m+i-t)] * sum_j F[j]*(-1)^(n-j)/(j!(n-j)!) * 1/(m+i-j);
// the sum over j is a convolution and is evaluated with NTT.
// Requirements visible from the code: n >= 1 (n*3 is passed to __builtin_clz), and every
// value m-n+t for t = 0..2n must be invertible modulo p because all of them are inverted.
// Side effect: overwrites the global facm[].
void lagrange(const int *F,int n,int m,int *R){
// inv_ is a local table of node inverses; it is distinct from the global inv[].
static int f[N],G[N],pre[N],suf[N],inv_[N];
memcpy(f,F,(n+1)<<2);
// k = 2n+1 inverses are needed; lim is the smallest power of two greater than 3n.
int tmp,k = n<<1|1,mul,lim = 1<<(32-__builtin_clz(n*3));
// If the requested range overlaps the known points, the first tmp outputs are simply
// F[m..n]; only the part starting at n+1 is interpolated.
if(m<=n) tmp = n-m+1,m = n+1;
else tmp = 0;
// facm[i] = prod_{t=0..n} (m+i-t); facm[0] is computed directly.
facm[0] = 1;
for(reg int i=0;i<=n;++i) facm[0] = (ll)facm[0]*(m-n+i)%p;
// Batch inversion of m-n, m-n+1, ..., m+n via prefix and suffix products:
// inv_[i] = 1/(m-n+i-1) for i = 1..k, with a single modular exponentiation.
pre[0] = suf[k+1] = 1;
for(reg int i=1;i<=k;++i) pre[i] = (ll)pre[i-1]*(m-n+i-1)%p;
for(reg int i=k;i;--i) suf[i] = (ll)suf[i+1]*(m-n+i-1)%p;
mul = power(pre[k],p-2);
for(reg int i=1;i<=k;++i) inv_[i] = (ll)mul*pre[i-1]%p*suf[i+1]%p;
// Slide the window: drop the factor (m+i-1-n) and append the factor (m+i).
for(reg int i=1;i<=n;++i) facm[i] = (ll)facm[i-1]*(m+i)%p*inv_[i]%p;
// f[i] = F[i] * (-1)^(n-i) / (i! (n-i)!)
for(reg int i=0;i<=n;++i){
f[i] = (ll)f[i]*ifac[i]%p*ifac[n-i]%p;
if((n-i)&1) f[i] = p-f[i];
}
// G[t] = 1/(m-n+t), so (f*G)[i+n] = sum_j f[j]/(m+i-j).
for(reg int i=0;i!=k;++i) G[i] = inv_[i+1];
// Zero the padding up to (and including) index lim before transforming.
memset(f+n+1,0,(lim-n)<<2);
memset(G+k,0,(lim-k+1)<<2);
NTT(f,1,lim),NTT(G,1,lim);
for(reg int i=0;i!=lim;++i) f[i] = (ll)f[i]*G[i]%p;
NTT(f,-1,lim);
// Already-known values first, then the interpolated ones.
memcpy(R,F+n-tmp+1,tmp<<2);
for(reg int i=tmp;i<=n;++i) R[i] = (ll)f[i+n-tmp]*facm[i-tmp]%p;
}
// Single-point Lagrange interpolation: given f[0..k], the values at 0..k of a
// polynomial of degree <= k, returns its value at the point n (modulo p).
inline int value(const int *f,int k,int n){
    static int pre[N],suf[N];
    // pre[i] = n(n-1)...(n-i)
    pre[0] = n;
    for(int i=1;i<=k;++i) pre[i] = (ll)pre[i-1]*(n-i)%p;
    // suf[i] = (n-i)(n-i-1)...(n-k)
    suf[k+1] = 1;
    for(int i=k;i;--i) suf[i] = (ll)suf[i+1]*(n-i)%p;
    // The i = 0 term has no prefix factor; its sign is (-1)^k.
    int res = (ll)ifac[k]*f[0]%p*suf[1]%p;
    if(k&1) res = p-res;
    for(int i=1;i<=k;++i){
        const int term = (ll)ifac[i]*ifac[k-i]%p*pre[i-1]%p*suf[i+1]%p*f[i]%p;
        // Term i carries the sign (-1)^(k-i).
        if((k-i)&1){
            res = dec(res,term);
        }else{
            res = add(res,term);
        }
    }
    return res;
}
// Batch modular inversion: R[i] = 1/f[i] (mod p) for i = 0..n.
// f[1..n] share one modular exponentiation via prefix/suffix products; f[0] is
// inverted on its own. f and R may be the same array: every input is read before the
// corresponding output slot is written.
inline void getinv(const int *f,int n,int *R){
    static int pre[N],suf[N];
    pre[0] = 1;
    for(int i=1;i<=n;++i) pre[i] = (ll)pre[i-1]*f[i]%p;
    suf[n+1] = 1;
    for(int i=n;i;--i) suf[i] = (ll)suf[i+1]*f[i]%p;
    // Inverse of the full product f[1]*...*f[n].
    const int all_inv = power(pre[n],p-2);
    for(int i=1;i<=n;++i){
        R[i] = (ll)all_inv*pre[i-1]%p*suf[i+1]%p;
    }
    R[0] = power(f[0],p-2);
}
// Returns sum_{i=0..m} C(n,i) modulo p.
// With s = floor(sqrt(m)), three polynomials are maintained through their values at
// the indices 0..d (index i stands for the argument s*i):
//   g_d(i) = prod_{t=1..d} (s*i + t)
//   h_d(i) = prod_{t=1..d} (n - s*i - t + 1)
//   f_d(i) = sum_{j=0..d-1} [prod_{t=1..j} (n - s*i - t + 1)] * [prod_{t=j+1..d} (s*i + t)]
// so that C(n, s*i + s) = C(n, s*i) * h_s(i)/g_s(i) and
// sum_{j=0..s-1} C(n, s*i + j) = C(n, s*i) * f_s(i)/g_s(i).
// NOTE: requires m >= 1. For m == 0 the halving loop leaves top at 0, the following
// st[top--] makes it -1, and the doubling loop then runs with invalid indices.
int solve(int n,int m){
int s = sqrt(m),top = 0,d,x,invs;
invs = power(s,p-2);
static int f[N],g[N],h[N],fd[N],gd[N],hd[N],st[30];
memset(f,0,sizeof(f));
memset(g,0,sizeof(g));
memset(h,0,sizeof(h));
// st[1] = s, st[2] = s/2, ..., st[top] = 1: the degrees to visit, largest first.
while(s){
st[++top] = s;
s >>= 1;
}
s = st[1],d = st[top--];
// Degree-1 values at the indices 0 and 1.
f[0] = g[0] = 1;
f[1] = g[1] = s+1;
h[0] = n,h[1] = n-s;
while(top--){
// Extend the known values from 0..d to d+1..2d+1.
lagrange(f,d,d+1,f+d+1);
lagrange(g,d,d+1,g+d+1);
lagrange(h,d,d+1,h+d+1);
f[d<<1|1] = g[d<<1|1] = h[d<<1|1] = 0;
// Shifting the index by d/s (mod p) moves the argument s*i to s*i + d.
int offset = (ll)d*invs%p;
lagrange(f,d<<1,offset,fd);
lagrange(g,d<<1,offset,gd);
lagrange(h,d<<1,offset,hd);
// Doubling: f_2d = f_d * (shifted g_d) + (shifted f_d) * h_d, and g, h are
// multiplied by their shifted copies. f must be updated before h is overwritten.
for(reg int i=0;i<=(d<<1);++i){
f[i] = ((ll)f[i]*gd[i]+(ll)fd[i]*h[i])%p;
g[i] = (ll)g[i]*gd[i]%p;
h[i] = (ll)h[i]*hd[i]%p;
}
d <<= 1;
if(!(st[top+1]&1)) continue;
// Target degree is odd: obtain the extra point d+1 by interpolation, then raise the
// degree by one: f_{d+1} = (f_d + h_d)*(x+d+1), g_{d+1} = g_d*(x+d+1),
// h_{d+1} = h_d*(n-x-d), where x = s*i.
f[d+1] = value(f,d,d+1);
g[d+1] = value(g,d,d+1);
h[d+1] = value(h,d,d+1);
for(reg int i=0;i<=d+1;++i){
x = s*i;
f[i] = (ll)(f[i]+h[i])*(x+d+1)%p;
g[i] = (ll)g[i]*(x+d+1)%p;
h[i] = (ll)h[i]*(n-x-d)%p;
}
++d;
}
// Walk over the s blocks of length s. After the first multiplication mul equals
// C(n, s*i)/g_s(i); after the last one it equals C(n, s*(i+1)).
int res = 0,mul = 1;
getinv(g,s-1,g);
for(reg int i=0;i!=s;++i){
mul = (ll)mul*g[i]%p;
res = (res+(ll)mul*f[i])%p;
mul = (ll)mul*h[i]%p;
}
// Add the term C(n, s*s) itself.
res = add(res,mul);
s *= s;
if(m==s) return res;
// Remaining terms s*s+1 .. m via C(n,i) = C(n,i-1)*(n-i+1)/i, with g reused to hold
// the batch-inverted denominators.
for(reg int i=s+1;i<=m;++i) g[i-s] = i;
getinv(g,m-s,g);
for(reg int i=s+1;i<=m;++i){
mul = (ll)mul*(n-i+1)%p*g[i-s]%p;
res = add(res,mul);
}
return res;
}
// Returns sum_{i=0..m} C(n,i) modulo p.
//   m < 0        : empty sum, 0
//   m >= n       : all subsets, 2^n
//   m <= 100000  : direct recurrence C(n,i) = C(n,i-1)*(n-i+1)/i using the inv[] table
//   m > n-m      : complement, 2^n - sum_{i=0..n-m-1} C(n,i)
//   otherwise    : the doubling algorithm in solve()
inline int binom_sum(int n,int m){
    if(m<0) return 0;
    if(m==0) return 1;
    if(m>=n) return power(2,n);
    if(m<=100000){
        int mul = 1,res = 1;
        for(int i=1;i<=m;++i){
            mul = (ll)mul*(n-i+1)%p*inv[i]%p;
            res = add(res,mul);
        }
        return res;
    }
    if(m>n-m){
        // The complementary bound n-m-1 is smaller than m and may be tiny (even 0), which
        // solve() cannot handle; going through binom_sum() again picks the right method.
        // It cannot take this branch a second time.
        const int rest = binom_sum(n,n-m-1);
        // Keep 2^n in a local: the dec macro would otherwise evaluate power() several times.
        const int total = power(2,n);
        return dec(total,rest);
    }
    return solve(n,m);
}
// Reads T test cases, each a pair n m, and prints sum_{i=0..m} C(n,i) modulo p for each.
int main(){
    // init(150000) selects the transform length 2^18 and fills inv[] up to 150000,
    // which covers the direct-summation branch of binom_sum() (m <= 100000).
    init(150000);
    int T,n,m;
    // Stop cleanly on missing or malformed input instead of using uninitialised values.
    if(scanf("%d",&T)!=1) return 0;
    while(T--){
        if(scanf("%d%d",&n,&m)!=2) break;
        printf("%d\n",binom_sum(n,m));
    }
    return 0;
}
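还是一样的套路,设 $s=\lfloor\sqrt n\rfloor$,$f_d(x)=\sum_{i=1}^{d}\dfrac{1}{x+i}$,则 $\sum_{i=1}^{s^2}\dfrac1i=\sum_{k=0}^{s-1}f_s(sk)$。
但这个形式实在不好搞,考虑拆成两部分:$g_d(x)=\prod_{i=1}^{d}(x+i)$,$h_d(x)=\sum_{i=1}^{d}\prod_{j\ne i}(x+j)$,于是 $f_d(x)=h_d(x)/g_d(x)$。
这样就有 $g_{2d}(x)=g_d(x)\,g_d(x+d)$,$h_{2d}(x)=h_d(x)\,g_d(x+d)+h_d(x+d)\,g_d(x)$,以及 $h_{d+1}(x)=h_d(x)\,(x+d+1)+g_d(x)$。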
注意
时间复杂度还是 $\Theta(\sqrt n\log n)$。
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
#define N 262147
#define ll long long
#define reg register
#define p 998244353
#define add(x,y) (x+y>=p?x+y-p:x+y)
#define dec(x,y) (x<y?x-y+p:x-y)
using namespace std;
// Square-and-multiply: returns a^t modulo p (intended for t >= 0).
inline int power(int a,int t){
    int prod = 1;
    for(;t;t >>= 1){
        if(t&1) prod = (ll)prod*a%p;
        a = (ll)a*a%p;
    }
    return prod;
}
int fac[N],ifac[N],rt[N],rev[N],facm[N],inv[N];
int siz;
// Precomputes the tables shared by the rest of the program for sizes up to n:
//   rev[]          bit-reversal permutation for the largest transform length lim,
//   rt[]           twiddle factors, rt[mid|k] = k-th power of the primitive (2*mid)-th root,
//   fac[], ifac[]  factorials and inverse factorials of 0..n modulo p,
//   inv[]          modular inverses of 1..n.
// lim is the smallest power of two strictly greater than n and must fit in the N-sized arrays.
void init(int n){
    int lim = 1;
    // siz is a global bit counter; reset it so that calling init() more than once
    // does not accumulate bits and corrupt rev[]/rt[].
    siz = 0;
    while(lim<=n) lim <<= 1,++siz;
    for(int i=1;i!=lim;++i) rev[i] = (rev[i>>1]>>1)|((i&1)<<(siz-1));
    // 3 is used as the primitive root of p, so w has multiplicative order lim.
    const int w = power(3,(p-1)>>siz);
    inv[1] = fac[0] = fac[1] = ifac[0] = ifac[1] = rt[lim>>1] = 1;
    for(int i=lim>>1|1;i!=lim;++i) rt[i] = (ll)rt[i-1]*w%p;
    for(int i=(lim>>1)-1;i;--i) rt[i] = rt[i<<1];
    for(int i=2;i<=n;++i) fac[i] = (ll)fac[i-1]*i%p;
    ifac[n] = power(fac[n],p-2);
    for(int i=n-1;i;--i) ifac[i] = (ll)ifac[i+1]*(i+1)%p;
    // Linear-time inverses: from p = (p/i)*i + p%i it follows that 1/i = -(p/i) * 1/(p%i).
    for(int i=2;i<=n;++i) inv[i] = (ll)(p-p/i)*inv[p%i]%p;
}
inline void NTT(int *f,int type,int lim){
if(type==-1) reverse(f+1,f+lim);
static unsigned long long a[N];
reg int x,shift = siz-__builtin_ctz(lim);
for(reg int i=0;i!=lim;++i) a[rev[i]>>shift] = f[i];
for(reg int mid=1;mid!=lim;mid<<=1)
for(reg int j=0;j!=lim;j+=(mid<<1))
for(reg int k=0;k!=mid;++k){
x = a[j|k|mid]*rt[mid|k]%p;
a[j|k|mid] = a[j|k]+p-x;
a[j|k] += x;
}
for(reg int i=0;i!=lim;++i) f[i] = a[i]%p;
if(type==1) return;
x = p-(p-1)/lim;
for(reg int i=0;i!=lim;++i) f[i] = (ll)f[i]*x%p;
}
// Shifted evaluation by Lagrange interpolation (all arithmetic modulo p).
//   Input : F[0..n] = f(0), f(1), ..., f(n) for a polynomial f of degree <= n.
//   Output: R[0..n] = f(m), f(m+1), ..., f(m+n).
// Uses f(m+i) = [prod_{t=0..n}(m+i-t)] * sum_j F[j]*(-1)^(n-j)/(j!(n-j)!) * 1/(m+i-j);
// the sum over j is a convolution and is evaluated with NTT.
// Requirements visible from the code: n >= 1 (n*3 is passed to __builtin_clz), and every
// value m-n+t for t = 0..2n must be invertible modulo p because all of them are inverted.
// Side effect: overwrites the global facm[].
void lagrange(const int *F,int n,int m,int *R){
// inv_ is a local table of node inverses; it is distinct from the global inv[].
static int f[N],G[N],pre[N],suf[N],inv_[N];
memcpy(f,F,(n+1)<<2);
// k = 2n+1 inverses are needed; lim is the smallest power of two greater than 3n.
int tmp,k = n<<1|1,mul,lim = 1<<(32-__builtin_clz(n*3));
// If the requested range overlaps the known points, the first tmp outputs are simply
// F[m..n]; only the part starting at n+1 is interpolated.
if(m<=n) tmp = n-m+1,m = n+1;
else tmp = 0;
// facm[i] = prod_{t=0..n} (m+i-t); facm[0] is computed directly.
facm[0] = 1;
for(reg int i=0;i<=n;++i) facm[0] = (ll)facm[0]*(m-n+i)%p;
// Batch inversion of m-n, m-n+1, ..., m+n via prefix and suffix products:
// inv_[i] = 1/(m-n+i-1) for i = 1..k, with a single modular exponentiation.
pre[0] = suf[k+1] = 1;
for(reg int i=1;i<=k;++i) pre[i] = (ll)pre[i-1]*(m-n+i-1)%p;
for(reg int i=k;i;--i) suf[i] = (ll)suf[i+1]*(m-n+i-1)%p;
mul = power(pre[k],p-2);
for(reg int i=1;i<=k;++i) inv_[i] = (ll)mul*pre[i-1]%p*suf[i+1]%p;
// Slide the window: drop the factor (m+i-1-n) and append the factor (m+i).
for(reg int i=1;i<=n;++i) facm[i] = (ll)facm[i-1]*(m+i)%p*inv_[i]%p;
// f[i] = F[i] * (-1)^(n-i) / (i! (n-i)!)
for(reg int i=0;i<=n;++i){
f[i] = (ll)f[i]*ifac[i]%p*ifac[n-i]%p;
if((n-i)&1) f[i] = p-f[i];
}
// G[t] = 1/(m-n+t), so (f*G)[i+n] = sum_j f[j]/(m+i-j).
for(reg int i=0;i!=k;++i) G[i] = inv_[i+1];
// Zero the padding up to (and including) index lim before transforming.
memset(f+n+1,0,(lim-n)<<2);
memset(G+k,0,(lim-k+1)<<2);
NTT(f,1,lim),NTT(G,1,lim);
for(reg int i=0;i!=lim;++i) f[i] = (ll)f[i]*G[i]%p;
NTT(f,-1,lim);
// Already-known values first, then the interpolated ones.
memcpy(R,F+n-tmp+1,tmp<<2);
for(reg int i=tmp;i<=n;++i) R[i] = (ll)f[i+n-tmp]*facm[i-tmp]%p;
}
// Evaluates, at the point n, the polynomial of degree <= k whose values at 0..k are
// f[0..k] (Lagrange interpolation, modulo p).
inline int value(const int *f,int k,int n){
    static int pre[N],suf[N];
    // pre[i] = n(n-1)...(n-i)
    pre[0] = n;
    for(int i=1;i<=k;++i) pre[i] = (ll)pre[i-1]*(n-i)%p;
    // suf[i] = (n-i)(n-i-1)...(n-k)
    suf[k+1] = 1;
    for(int i=k;i;--i) suf[i] = (ll)suf[i+1]*(n-i)%p;
    // The i = 0 term has no prefix factor; its sign is (-1)^k.
    int res = (ll)ifac[k]*f[0]%p*suf[1]%p;
    if(k&1) res = p-res;
    for(int i=1;i<=k;++i){
        const int term = (ll)ifac[i]*ifac[k-i]%p*pre[i-1]%p*suf[i+1]%p*f[i]%p;
        // Term i carries the sign (-1)^(k-i).
        if((k-i)&1){
            res = dec(res,term);
        }else{
            res = add(res,term);
        }
    }
    return res;
}
// Batch modular inversion: writes R[i] = 1/f[i] (mod p) for i = 0..n.
// The entries 1..n are inverted together through prefix/suffix products and one
// modular exponentiation; entry 0 is inverted separately. Calling it with R == f is
// fine because each input is consumed before its output slot is written.
inline void getinv(const int *f,int n,int *R){
    static int pre[N],suf[N];
    pre[0] = 1;
    for(int i=1;i<=n;++i) pre[i] = (ll)pre[i-1]*f[i]%p;
    suf[n+1] = 1;
    for(int i=n;i;--i) suf[i] = (ll)suf[i+1]*f[i]%p;
    // Inverse of the full product f[1]*...*f[n].
    const int all_inv = power(pre[n],p-2);
    for(int i=1;i<=n;++i){
        R[i] = (ll)all_inv*pre[i-1]%p*suf[i+1]%p;
    }
    R[0] = power(f[0],p-2);
}
// Returns the harmonic number H_n = sum_{i=1..n} 1/i modulo p.
// With s = floor(sqrt(n)), two polynomials are maintained through their values at the
// indices 0..d (index i stands for the argument s*i):
//   g_d(i) = prod_{t=1..d} (s*i + t)
//   h_d(i) = sum_{t=1..d} prod_{u != t} (s*i + u)
// so that h_d(i)/g_d(i) = sum_{t=1..d} 1/(s*i + t). d grows from 1 to s by doubling
// (plus one extra step whenever the target is odd); the answer for s*s is
// sum_{i=0..s-1} h_s(i)/g_s(i), and the remaining terms s*s+1 .. n are added directly.
int harmonic(int n){
    // H_0 is the empty sum. Without this guard s would be 0, the halving chain empty,
    // and the doubling loop would start with top == -1 and invalid indices.
    if(n<=0) return 0;
    int s = sqrt(n),top = 0,d,x,invs;
    invs = power(s,p-2);
    static int g[N],h[N],gd[N],hd[N],st[30];
    memset(h,0,sizeof(h));
    memset(g,0,sizeof(g));
    // st[1] = s, st[2] = s/2, ..., st[top] = 1: the degrees to visit, largest first.
    while(s){
        st[++top] = s;
        s >>= 1;
    }
    s = st[1],d = st[top--];
    // Degree-1 values at the indices 0 and 1: g_1 = s*i + 1, h_1 = 1.
    g[0] = h[0] = h[1] = 1;
    g[1] = s+1;
    while(top--){
        // Extend the known values from 0..d to d+1..2d+1.
        lagrange(g,d,d+1,g+d+1);
        lagrange(h,d,d+1,h+d+1);
        g[d<<1|1] = h[d<<1|1] = 0;
        // Shifting the index by d/s (mod p) moves the argument s*i to s*i + d.
        int offset = (ll)d*invs%p;
        lagrange(g,d<<1,offset,gd);
        lagrange(h,d<<1,offset,hd);
        // Doubling: h_2d = h_d * (shifted g_d) + (shifted h_d) * g_d, g_2d = g_d * (shifted g_d).
        // h must be updated first because it needs the old g.
        for(int i=0;i<=(d<<1);++i){
            h[i] = ((ll)h[i]*gd[i]+(ll)hd[i]*g[i])%p;
            g[i] = (ll)g[i]*gd[i]%p;
        }
        d <<= 1;
        if(!(st[top+1]&1)) continue;
        // Target degree is odd: obtain the extra point d+1 by interpolation, then raise
        // the degree by one: h_{d+1} = h_d*(x+d+1) + g_d, g_{d+1} = g_d*(x+d+1), x = s*i.
        g[d+1] = value(g,d,d+1);
        h[d+1] = value(h,d,d+1);
        for(int i=0;i<=d+1;++i){
            x = s*i;
            h[i] = ((ll)h[i]*(x+d+1)+g[i])%p;
            g[i] = (ll)g[i]*(x+d+1)%p;
        }
        ++d;
    }
    // Sum h_s(i)/g_s(i) over the s blocks, inverting all denominators in one batch.
    getinv(g,s-1,g);
    ll res = 0;
    for(int i=0;i!=s;++i) res += (ll)h[i]*g[i]%p;
    s *= s;
    if(s==n) return res%p;
    // Remaining terms 1/(s*s+1) .. 1/n, again with a batch inversion (g is reused).
    for(int i=s+1;i<=n;++i) g[i-s] = i;
    getinv(g,n-s,g);
    for(int i=s+1;i<=n;++i) res += g[i-s];
    return res%p;
}
// Reads T test cases, each a single integer n, and prints H_n = sum_{i=1..n} 1/i modulo p.
int main(){
    // init(150000) selects the transform length 2^18, the largest the N-sized buffers hold.
    init(150000);
    int T,n;
    // Stop cleanly on missing or malformed input instead of using an uninitialised value.
    if(scanf("%d",&T)!=1) return 0;
    while(T--){
        if(scanf("%d",&n)!=1) break;
        printf("%d\n",harmonic(n));
    }
    return 0;
}
link
实际上就是求这个式子
设
但是这样
这样就有
Code
就先写这么五种吧,然而这种做法可以求的远不止这些,可以试着自己扩展一下。
数学真是有趣啊!(棒读)
不知道写什么了
溜了