【模板】多项式三角函数

· · 个人记录

给定 F(x),求出 G(x),使得

G(x)\equiv \sin{F(x)} \pmod {x^n}

G(x)\equiv \cos{F(x)} \pmod {x^n}

系数对 998244353 取模。

Solution

Euler 定理,我们知道

e^{\text{i}x}=\cos{x}+\text{i}\sin x

x=-x,我们得到

e^{-\text{i}x}=\cos{x}-\text{i}\sin x

[ 其中已利用 \cos {(-x)}= \cos{x}\sin{(-x)}=-\sin{x} ]

联立以上两式,不难解得

\boxed{\sin{x}=\frac{e^{\text{i}x}-e^{-\text{i}x}}{2\text{i}}} \boxed{\cos{x}=\frac{e^{\text{i}x}+e^{-\text{i}x}}{2}}

于是你只需要用一个 多项式 \exp 的板子。

喂!等等!我们先前的板子都是用的 FNTT,可是这里出现了虚数单位 \text{i} 诶!

我们是在 \bmod {998244353} 意义下运算的。显然有

\text{i}^2\equiv-1 \pmod {998244353}

\text{i}^2\equiv \color{red}{998244352} \pmod {998244353}

事实上只需要取

\boxed{\text{i}\equiv g^{\frac{p-1}{4}}}

\text{i}\equiv86583718 \pmod {998244353}

即可。其中 g 为模 p 原根。

于是就又水过了一道紫题。

复杂度 \Theta{\left(N \log{N}\right)}。耗时最少 2.12\text{s}。(Record)

#include <bits/stdc++.h>
using namespace std;
const int MAXN=1e5+10, p=998244353, G=3;
#define int long long
#define reg register
static int rev[MAXN*4];
static int f[MAXN*4], g[MAXN*4], h[MAXN*4];
int n, m;
inline int read() { int x=0,f=1;char ch=getchar();
    while (ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
    while (ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();} return x*f;
}
inline void write(int x) { if (x<0) { putchar('-'), write(-x); return; }
    if (x>9) {write(x/10);} putchar(x%10+'0');
}
inline int modu(int x) { if (x<p) return x;
    return x-x/p*p;
}
inline int modu(int x, int y) { x+=y; if (x>=p) x-=p; return x; }
int qpow(int a, int b, int p) { modu(a); int res=1;
    while (b) {if (b&1) res=modu(res*a);a=modu(a*a);b>>=1;} return res;
}
inline void fntt(int f[], int n, int sgn=1) { int bit=__lg(n);
    for (reg int i=1; i<n; ++i) { rev[i]=(rev[i>>1]>>1) | ((i&1))<<(bit-1);
        if (i<rev[i]) swap(f[i], f[rev[i]]); }
    for (reg int l=1, t=1; l<n; l<<=1, ++t) { 
        int step=qpow(G, ((p-1)>>t)*sgn+p-1, p);
        for (reg int i=0; i<n; i+=l<<1) {
            for (reg int k=i, cur=1; k<i+l; ++k, cur=cur*step%p) {
                int g=f[k], h=modu(f[k+l]*cur);
                f[k]=modu(g,h), f[k+l]=modu(g,p-h);
            }
        }
    }
    if (sgn==-1) {
        int inv=qpow(n,p-2,p);
        for (reg int i=0; i<n; ++i)
            f[i]=modu(f[i]*inv);
    }
}
#define inv(a) qpow(a,p-2,p)
inline void solve_inv(int n, int*f, int *g) {
    if (n==1) {
        g[0]=inv(f[0]); return;
    }
    solve_inv((n+1)>>1,f,g);
    int N=1;
    while (N < (n<<1)) N<<=1;
    for (reg int i=0; i<n; ++i) h[i]=f[i];
    for (int i=n; i<N; ++i) h[i]=0;
    fntt(h,N), fntt(g,N);
    for (reg int i=0; i<N; ++i) {
        g[i]=modu(modu(2-modu(g[i]*h[i])+p)*g[i]); // G'*(2-F*G')
    }
    fntt(g,N,-1);
    for (reg int i=n; i<=N; ++i) g[i]=0;
}
inline void get_inv(int n, int* f, int *g) { // 求f的逆元,结果存在g中
    solve_inv(n,f,g);
    for (reg int i=0; i<n; ++i) {
        g[i]=modu(g[i]+p);
    }
}
inline void deriv(int n, int *f, int *g) { // 对f求导,结果存在g中
    for (reg int i=0; i<n; ++i) {
        g[i]=(i+1)*f[i+1]%p;
    }
}
inline void inte(int n, int *f, int *g) { // 求f的积分,结果存在g中
    g[0]=0;
    for (int i=1; i<=n; ++i) {
        g[i]=modu(f[i-1]*inv(i));
    }
}
inline void get_ln(int n, int *f, int *g) { // 求f的ln,结果存在g中
    static int df[MAXN*4], invf[MAXN*4], dg[MAXN*4];
    int N=1<<(__lg(n+n+1+1)+1);
    for (reg int i=0; i<N; ++i) df[i]=invf[i]=dg[i]=0; // 记得清空
    deriv(n,f,df);
    get_inv(n,f,invf);
    // F' * F^(-1)
    fntt(df,N); fntt(invf,N);
    for (reg int i=0; i<N; ++i)
        dg[i]=modu(df[i]*invf[i]);
    fntt(dg,N,-1);
    inte(n,dg,g);
}
static int ln_g[MAXN*4];
inline void get_exp(int n, int *f, int *g) { // 求f的exp,结果存在g中           
    if (n==1) {g[0]=1; return;}
    get_exp((n+1)>>1, f, g);
    for (reg int i=0; i<n; ++i) ln_g[i]=0;
    get_ln(n, g, ln_g);
    int N=1; while (N<n<<1) N<<=1;
    for (reg int i=n; i<N; ++i) g[i]=0;
    for (reg int i=0; i<n; ++i) ln_g[i]=modu(f[i],p-ln_g[i]);
    ln_g[0]++;
    fntt(ln_g, N, 1); fntt(g, N, 1);
    for (reg int i=0; i<N; ++i) g[i]=modu(g[i]*ln_g[i]);
    fntt(g, N, -1);
}   
const int I=86583718, inv_i=inv(I);
inline void get_sin(int n, int *f, int *g) {
    static int tmp[MAXN*4], exp_f[MAXN*4], inv_exp[MAXN*4];
    for (int i=0; i<n; ++i) tmp[i]=f[i]*I%p,
        exp_f[i]=0, inv_exp[i]=0;
    get_exp(n,tmp,exp_f); get_inv(n,exp_f,inv_exp);
    int inv_2i=inv(I<<1);
    for (int i=0; i<n; ++i) g[i]=((exp_f[i]-inv_exp[i]+p)%p*inv_2i%p+p)%p;
}
inline void get_cos(int n, int *f, int *g) {
    static int tmp[MAXN*4], exp_f[MAXN*4], inv_exp[MAXN*4];
    for (int i=0; i<n; ++i) tmp[i]=f[i]*I%p,
        exp_f[i]=0, inv_exp[i]=0;
    get_exp(n,tmp,exp_f); get_inv(n,exp_f,inv_exp);
    int inv_2=inv(2);
    for (int i=0; i<n; ++i) g[i]=((exp_f[i]+inv_exp[i]+p)%p*inv_2%p+p)%p;
}
int type;
signed main() {
    n=read(); type=read();
    for (reg int i=0; i<n; ++i)
        f[i]=read();
    if (type==0) get_sin(n,f,g);
    else get_cos(n,f,g);
    for (reg int i=0; i<n; ++i) {
        write(g[i]); putchar(' ');
    }
    return 0;
}