题解 P1919 【【模板】A*B Problem升级版(FFT快速傅里叶)】

· · 题解

发一下NTT版本的题解

关于NTT,实际上就是利用一些特殊的质数(例如 998244353,这些质数的特征都是可以可以表示 q \times 2 ^ k + 1 的形式)进行膜意义下的运算代替浮点复数运算来保证精度,原理是质数原根和复数单位根在DFT运算中具有相同的性质

具体实现看代码:(其中 998244353 的原根 g = 3

/**************************************************************
    Problem: 2179
    User: infinityedge
    Language: C++
    Result: Accepted
    Time:3424 ms
    Memory:7052 kb
****************************************************************/

#include <iostream>
#include <cstring>
#include <cmath>
#include <cstdio>
#include <algorithm>
#include <queue>
#include <vector>
#include <map>
#include <complex>

#define inf 0x3f3f3f3f
#define N 131072

#define lc k << 1
#define rc k << 1 | 1

using namespace std;

typedef long long ll;
const ll p = 998244353;
typedef complex<double> E;

const double pi = acos(-1);

ll qpow(ll a, ll b, ll p){
    ll res = 1;
    for(; b; b >>= 1, a = a * a % p){
        if(b & 1) res = res * a % p;
    }
    return res;
}

int n;

ll a[N], b[N];
ll tmp[N];

void bit_reverse(int n, ll* r){
    for(int i = 0, j = 0; i < n; i ++){
        if(i > j) swap(r[i], r[j]);
        for(int l = n >> 1; (j ^= l) < l; l >>= 1);
    }
}
ll inv3;
ll pow3[N], powinv3[N];

void pre(){
    inv3 = qpow(3, p - 2, p);
    for(int i = 1; i <= n; i <<= 1) pow3[i] = qpow(3, (p - 1) / i, p);
    for(int i = 1; i <= n; i <<= 1) powinv3[i] = qpow(inv3, (p - 1) / i, p);
}

void NTT(int n, ll* r, ll f){
    bit_reverse(n, r);
    for(int i = 2; i <= n; i <<= 1){
        int m = i >> 1;
        for(int j = 0; j < n; j += i){
            ll w = 1, wn = pow3[i];
            if(f == -1) wn = powinv3[i];
            for(int k = 0; k < m; k ++){
                ll z = r[j + m + k] * w % p;
                r[j + m + k] = (r[j + k] - z + p) % p;
                r[j + k] = (r[j + k] + z) % p;
                w = w * wn % p;
            }
        }
    }
    if(f == -1){
        ll inv = qpow(n, p - 2, p);
        for(int i = 0; i < n; i ++) r[i] = r[i] * inv % p;
    }
}

int c[N]; char ch[N];
int main(){
    scanf("%d", &n); 
    scanf("%s", ch);
    for(int i = 0; i < n; i ++){
        a[i] = ch[n - i - 1] - '0';
    }
    scanf("%s", ch);
    for(int i = 0; i < n; i ++){
        b[i] = ch[n - i - 1] - '0';
    }
    int m = n * 2; n = 1;
    while(n <= m) n <<= 1;
    pre();
    NTT(n, a, 1);
    NTT(n, b, 1);
    for(int i = 0; i < n; i ++) a[i] = a[i] * b[i] % p;
    NTT(n, a, -1);
    for(int i = 0; i < n; i ++) c[i] = a[i];
    for(int i = 0; i < n; i ++){
        if(c[i] >= 10){
            c[i + 1] += c[i] / 10;
            c[i] %= 10;
        }   
    }
    int pp = n - 1;
    while(c[pp] == 0) pp--;
    for(int i = pp; i >= 0; i --){
        printf("%d", c[i]);
    }
    return 0;
}