题解:P4721 【模板】分治 FFT

· · 题解

题目传送门

前置知识:FFT / NTT。

给定序列 g_{1\dots n - 1},求序列 f_{0\dots n - 1}。\ 其中 f_i=\sum_{j=1}^if_{i-j}g_j,边界为 f_0=1。\ 答案对 998244353 取模。

思路

观察题目,我们发现式子是一个卷积的形式。

但是 f 的每一项都是由前 O(n)递推来的。所以此时直接 FFT / NTT 会达到恐怖的 O(n^2 \log n)(暴力是 O(n^2)),所以 O(n \log n) 是行不通的。

那么,O(n \log^2 n) 呢?

我们采用——分治

对于一个我们要求的系数区间 [l, r] (就是 fl 次方项到 r 次方项的系数),具体是这样的:

  1. 先求出左半边 [l, mid - 1]f 对自己的贡献,即只使用 [l, mid - 1] 中的元素与 g 相乘,且只需要答案的 [l, mid - 1] 部分,我们发现这个可以递归。
  2. 求出左半边对右半边的贡献,即用 [l, mid - 1] 中的元素乘 g,将乘积的 [mid, r] 这几项加给 f 的对应次方系数。
  3. 递归右半边。

我们发现操作中的 g 只有前 r - l + 1 项有用(后面的我们不需要),所以时间复杂度为递归乘 FFT / NTT 为 O(n \log^2 n)

注:最好补齐项数到 2 的次幂,这样每个递归区间长度都也是 2 的次幂,便于乘法。

Code

#include <bits/stdc++.h>
using namespace std;

#define int long long
constexpr int maxn = (1 << 21) + 10, modd = 998244353, g1 = 3, gi = 332748118;

int n, G[maxn], F[maxn], rev[maxn];

int ksm(int a, int b) { // 快速幂
    int ress = 1;
    while (b) {
        if (b & 1) ress = ress * a % modd;
        a = a * a % modd;
        b >>= 1;
    }
    return ress;
}

void NTT(int *f, int len, int k) { // NTT 与模板毫无需要改动的地方(FFT 当然也可以)
    for (int i = 0; i < len; i++) {
        if (rev[i] > i) swap(f[i], f[rev[i]]); 
    }
    for (int d = 1, g, dg, nowa, nowb; d < len; d <<= 1) {
        dg = ksm((k == 1 ? g1 : gi), (modd - 1) / (d << 1));
        for (int i = 0; i < len; i += (d << 1)) {
            g = 1;
            for (int j = 0; j < d; j++) {
                nowa = f[i + j];
                nowb = g * f[i + j + d];
                f[i + j] = (nowa + nowb) % modd;
                f[i + j + d] = (nowa - nowb) % modd;
                g = g * dg % modd;
            }
        }
    }
    return ;
}

void solve(int l, int r) { // 核心代码
    if (l == r) { // 边界
        return ;
    }
    int mid = (r + l + 1) >> 1, nowl = (r - l + 1) << 1; // 这个 mid 把区间化成了 [l, mid - 1], [mid, r]
    solve(l, mid - 1); // 递归左边
    rev[0] = 0;
    rev[1] = (nowl >> 1);
    for (int i = 4, k = (nowl >> 2); i <= nowl; i <<= 1, k >>= 1) {
        for (int j = (i >> 1); j < i; j++) {
            rev[j] = rev[j - (i >> 1)] + k;
        }
    } // 这里 FFT / NTT 的长度会变化,所以需要动态维护 rev
    int nowf[nowl], nowg[nowl];
    for (int i = l; i < mid; i++) {
        nowf[i - l] = F[i];
    }
    for (int i = mid - l; i < nowl; i++) { // 注意不是全局变量的数组要赋初值
        nowf[i] = 0;
    }
    for (int i = 0; i <= r - l; i++) {
        nowg[i] = G[i];
    }
    for (int i = r - l + 1; i < nowl; i++) {
        nowg[i] = 0;
    }
    NTT(nowf, nowl, -1);
    NTT(nowg, nowl, -1);
    for (int i = 0; i < nowl; i++) {
        nowf[i] = (nowf[i] * nowg[i]) % modd;
    }
    NTT(nowf, nowl, 1);
    for (int i = 0; i < nowl; i++) {
        nowf[i] *= ksm(nowl, modd - 2);
        nowf[i] = (nowf[i] % modd + modd) % modd;
    }
    for (int i = mid; i <= r; i++) { // 加到右边
        F[i] += nowf[i - l];
        F[i] %= modd;
    }
    solve(mid, r); // 递归右边
    return ;
}

signed main() {
    ios::sync_with_stdio(false), cin.tie(nullptr), cout.tie(nullptr);
    cin >> n;
    n--;// 题面只需要求到 n - 1
    F[0] = 1;
    for (int i = 1; i <= n; i++) {
        cin >> G[i];
    }
    int maxl = 1;
    while (maxl <= n) maxl <<= 1; // 补全成 2 的次幂
    solve(0, maxl - 1);

    cout << 1 << ' ';
    for (int i = 1; i <= n; i++) {
        cout << F[i] << ' ';
    }

    return 0; // 因为 long long, 我调了 4h
}

十年 OI 一场空,不开 long long 见祖宗。