MTT-fft

· · 个人记录

给出一种使用转置(dit / dif)的4次FFT解决MTT问题的方法

不妨把 dif 看作为普通 fft 后再按位翻转的结果

那么如果要提取其中的反转单位根处的点值并保持位反转的形态

可以先做一次位反转, 转换为普通顺序的点值; 再翻转区间 [f + 1, f + lim) (即 std::reverse(f + 1, f + lim)); 最后再做一次位反转, 变回位反转顺序的点值

若设与i共轭的位置为f(i)则不难发现

f(i) = 3 * bit_floor(i) - i - 1 (i ≥ 1; 特别地 f(0) = 0)

例如, 变换长度为16时就会如同这样

0 -> 0

1 -> 1

2 -> 3, 3 -> 2

4 -> 7, 5 -> 6, 6 -> 5, 7 -> 4

8 -> 15, 9 -> 14, 10 -> 13, 11 -> 12, 12 -> 11, 13 -> 10, 14 -> 9, 15 -> 8

证明:

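let lim = 2 ^ lgn, 以下的 ~ / bit_rev / clz 均在 lgn 位下进行

引理: 记 c = count_one_r(x) 为 x 末尾连续 1 的个数, 则

bit_rev(x + 1) ≡ bit_rev(x) + g(c) (mod lim)

其中 g(c) = (c == lgn ? 1 : 2 ^ (lgn - 1 - c) + 2 ^ (lgn - c))

(x + 1 会把末尾的 c 个 1 变为 0 并把第 c 位变为 1; 位反转后即最高的 c 位清零、第 lgn - 1 - c 位置 1)

对 x ≥ 1, 注意到 count_one_r(bit_rev(~x)) = clz(x) < lgn 以及 bit_rev(bit_rev(~x)) = ~x ≡ -x - 1 (mod lim), 于是

bit_rev(lim - bit_rev(x))

≡ bit_rev(~bit_rev(x) + 1) (mod lim)

≡ bit_rev(bit_rev(~x) + 1) (mod lim)

≡ bit_rev(bit_rev(~x)) + g(clz(x)) (mod lim)

≡ 2 ^ (lgn - 1 - clz(x)) + 2 ^ (lgn - clz(x)) - x - 1 (mod lim)

≡ 2 ^ (lgn - 1 - clz(x)) * 3 - x - 1 (mod lim)

≡ 3 * bit_floor(x) - x - 1 (mod lim)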

由此不难写出代码,注意边界要特判一下

#include<algorithm>
#include<cmath>
#include<complex>
#include<iostream>
#include<stdexcept>
#include<vector>

// Shorthand aliases used throughout the file.
using f64 = double;
using u64 = unsigned long long;
using cpx = std::complex<f64>;
// A polynomial stored as its coefficient vector.
using Poly = std::vector<int>;

// Size of the static FFT work buffers (f0/f1 hold N entries, the twiddle table N / 2).
constexpr int N = 1 << 18;
// Returns the smallest power of two that is >= x; every x <= 1 yields 1.
// Written as a single-return recursion (ceil-halving) so that it stays a valid
// C++11 constexpr function and no longer relies on the libstdc++-internal
// std::__lg, whose result is undefined for the argument 0 (hit at x == 1).
// Precondition: x <= 2^30, otherwise the result does not fit in int.
constexpr int bceil(int x){return x <= 1 ? 1 : 2 * bceil(x / 2 + x % 2);}
// Radix-2 complex FFT kernels that keep the spectrum in bit-reversed order:
// dif() turns natural-order input into a bit-reversed spectrum and dit()
// undoes it (up to a factor of lim), so no explicit bit-reversal permutation
// is ever performed.
namespace f_f_t{
// pi / 2, the argument of w[1].
const f64 Pi_2 = std::acos(-1.0) / 2;

// Twiddle table: w[0] = 1, w[2^i] = exp(I * pi / 2^(i + 1)), and every other
// entry is the product of the entries selected by the set bits of its index.
cpx w[N >> 1];

// Largest transform length the table currently supports (0 = not built yet).
int init_l = 0;

// Makes sure the twiddle table covers transforms of length up to l.
// Precondition: l <= N (the table has N / 2 slots; not checked here).
void init(int l){
    if(l <= init_l){return;}
    // t = ceil(log2(l)) - 1, clamped at 0: log2 of the number of table entries
    // needed. Counting it with a loop replaces std::__lg(l - 1), which is a
    // libstdc++ internal and is undefined for l == 1.
    int t = 0;
    while((2 << t) < l){++t;}
    const int half = 1 << t;
    w[0] = 1.0, init_l = half << 1;
    // Seed the power-of-two slots ...
    for(int i = 0; i < t; ++i){w[1 << i] = std::polar(1.0, Pi_2 / (1 << i));}
    // ... then build each entry from (index without its lowest set bit) times
    // (its lowest set bit). For a power-of-two index this is w[0] * w[i] = w[i].
    for(int i = 1; i < half; ++i){w[i] = w[i & (i - 1)] * w[i & -i];}
}

// In-place forward transform of f[0, lim). lim must be a power of two that
// init() has already been called for. Stages run from the widest butterflies
// (half-width lim / 2) down to half-width 1; the b-th block of every stage is
// twiddled by w[b], applied before the add/subtract.
void dif(cpx *f, int lim){
    for(int half = lim >> 1; half > 0; half >>= 1){
        const cpx *tw = w;
        for(cpx *blk = f; blk != f + lim; blk += half << 1, ++tw){
            for(cpx *k = blk; k != blk + half; ++k){
                const cpx x = *k, y = k[half] * *tw;
                *k = x + y, k[half] = x - y;
            }
        }
    }
}

// In-place inverse of dif() without the 1 / lim scaling (the caller applies
// it). Stages run in the opposite order, from half-width 1 up to lim / 2, and
// the conjugate twiddle is applied after the subtract.
void dit(cpx *f, int lim){
    for(int half = 1; half < lim; half <<= 1){
        const cpx *tw = w;
        for(cpx *blk = f; blk != f + lim; blk += half << 1, ++tw){
            for(cpx *k = blk; k != blk + half; ++k){
                const cpx x = *k, y = k[half];
                *k = x + y, k[half] = (x - y) * std::conj(*tw);
            }
        }
    }
}

}
// Make the two transform kernels callable without the namespace prefix.
using f_f_t::dif;
using f_f_t::dit;

// Arbitrary-modulus polynomial multiplication ("MTT") with four FFTs in total:
// one forward transform per operand and two inverse transforms.
namespace MTT{

// Modulus applied to every output coefficient. It is zero-initialised, so it
// must be set to a positive value before Conv() is called (Conv takes % mod).
int mod;

// Splits every coefficient as a[i] = hi * 2^15 + lo, stores hi + lo * I in
// f[i], zero-pads f up to length l and applies the forward transform.
// NOTE(review): the shift/mask split assumes non-negative coefficients.
void prep(const Poly &a, cpx *f, int l){
    const int n = a.size();
    for(int i = 0; i < n; ++i){
        f[i] = cpx(a[i] >> 15, a[i] & 32767);
    }
    std::fill(f + n, f + l, 0), dif(f, l);
}

// Work buffers shared by every Conv() call (so Conv is not reentrant).
cpx f0[N], f1[N];

// Returns the product of a and b with every coefficient reduced modulo `mod`.
// The result has a.size() + b.size() - 1 coefficients, or none if either
// operand is empty.
// Throws std::length_error if the result would not fit into the N-entry
// buffers (previously this silently wrote out of bounds).
Poly Conv(const Poly &a, const Poly &b){
    if(a.empty() || b.empty()){return Poly();}
    if(a.size() + b.size() - 1 > static_cast<Poly::size_type>(N)){
        throw std::length_error("MTT::Conv: result length exceeds the FFT buffer size N");
    }
    const int n = a.size(), m = b.size(), u = n + m - 1;
    // Use a transform of at least length 2 so that positions 0 and 1 below
    // always exist and bceil()/init() are never asked about length 1.
    const int l = bceil(std::max(u, 2));
    f_f_t::init(l), prep(a, f0, l), prep(b, f1, l);
    // Separate the hi/lo spectra of a, multiply by the spectrum of b, then
    // invert. The 1 / l of the inverse transform (and the 1 / 2 of the
    // separation) is folded into fx; the transform is linear, so it does not
    // matter when the scaling is applied.
    const f64 fx = 0.5 / l;
    // Positions 0 and 1 are their own conjugate partners.
    for(int i = 0; i < 2; ++i){
        const cpx p = f0[i], r = f1[i] * fx;
        f0[i] = (p + std::conj(p)) * r, f1[i] = (std::conj(p) - p) * r;
    }
    // Inside each block [k, 2k) the partner of i is 3k - 1 - i, so walk the
    // block from both ends towards the middle.
    for(int k = 2; k < l; k <<= 1){
        for(int i = k, j = (k << 1) - 1; i < j; ++i, --j){
            const cpx p = f0[i], q = f0[j], r = f1[i] * fx, s = f1[j] * fx;
            // Same combination as in the usual four-FFT method.
            f0[i] = (p + std::conj(q)) * r, f1[i] = (std::conj(q) - p) * r;
            f0[j] = (q + std::conj(p)) * s, f1[j] = (std::conj(p) - q) * s;
        }
    }
    dit(f0, l), dit(f1, l);
    // Now f0 = hiA*hiB + I * hiA*loB and f1 = loA*loB - I * loA*hiB, hence
    // result = Re(f0) * 2^30 + (Im(f0) - Im(f1)) * 2^15 + Re(f1).
    Poly c(u);
    for(int i = 0; i < u; ++i){
        // "+ 0.5 then truncate" rounds the non-negative values to nearest.
        const u64 hi = u64(f0[i].real() + 0.5) % mod;
        const u64 mid = u64((f0[i].imag() - f1[i].imag()) + 0.5) % mod;
        const u64 lo = u64(f1[i].real() + 0.5);
        c[i] = ((hi << 30) + (mid << 15) + lo) % mod;
    }
    return c;
}
}

void solve(){
    int n, m;
    std::cin >> n >> m >> MTT::mod;
    Poly F(n + 1), G(m + 1);
    for(auto &x : F){
        std::cin >> x;
    }
    for(auto &x : G){
        std::cin >> x;
    }
    for(auto x : MTT::Conv(F, G)){
        std::cout << x << ' ';
    }
}

// Program entry point: unties cin from cout and turns off C stdio
// synchronisation for faster stream I/O, then runs the single test case.
int main(){
    std::cin.tie(nullptr);
    std::ios_base::sync_with_stdio(false);
    solve();
    return 0;
}