高精度多项式乘法的另一种实现

· · 个人记录

传统的高精度多项式乘法一般使用三模NTT或者拆系数FFT实现。

三模NTT的精度为 10 ^ {26} 量级,拆系数FFT的精度大概也为这个量级(除非你打算用能使常数翻10倍的__float128),而且两者均受到变换长度的制约。

给出一种大力出奇迹的高精度多项式乘法,其精度能达到 10 ^ {37} 级别且不会受到变换长度的制约。

同时速度没有太慢。(任意模数多项式乘法 cin,cout输入输出 900ms)

我们考虑NTT的值域受到模数的限制,因此我们可以考虑使用一个巨大的 NTT 模数,如果使用一个 int128 范围内的NTT模数即可解决问题。

对 int128 的模乘可以通过拆位乘和蒙哥马利约减实现,但是我没有在网上找到 int128 范围内的 NTT 模数,这里给出两个。

不难发现 21267647932558654224715329996419235841 = 2^{65} \times 3^1 \times 5^1 \times 23^1 \times 101^1 \times 386719^1 \times 42779309^1 + 1 且是一个质数,它略大于 2 ^ {124}

手玩即可发现它的最小原根是 11

不难发现 85070591730234613320192969686023929857 = 2^{65} \times 138143^1 \times 16691710830181^1 + 1 且是一个质数,它略小于 2 ^ {126}

手玩即可发现它的最小原根是 3

剩下的就是抄一下 NTT 板子,修改一下即可。

给出一个十分糟糕的实现。

#include<iostream>

using i64 = long long;
using u64 = unsigned long long;
using i128 = __int128;
using u128 = __uint128_t;

//对128位数进行的基础支持
namespace u128_s{
    constexpr u128 stu128(const char* s){
        u128 x = 0;
        while (*s) {x = x * 10 + (*s++ - '0');}
        return x;
    }
    u128 stu128(const std::string &s){
        return stu128(s.data());
    }
    std::string to_string(u128 x){
        char bbuf[40], *p = bbuf + 40;
        do {
            *--p = (x % 10) ^ 48, x /= 10;
        } while (x > 0);
        return std::string(p, bbuf + 40);
    }
    constexpr void exgcd(i128 a, i128 b, i128 &x, i128 &y) {
        if (b == 0){
            return ;
        }
        exgcd(b, a % b, y, x), y -= a / b * x;
    }
    constexpr u128 inv128(u128 t,u128 m) {
        i128 q(m), x(1), y(0);
        return exgcd(t, q, x, y), (x %= i128(m)) < 0 ? x + m : x;
    }
};

std::istream& operator >> (std::istream& IN, u128 &x){
    std::string s;
    return IN >> s, x = u128_s::stu128(s), IN;
}

std::ostream& operator << (std::ostream& OUT, u128 x){
    return OUT << u128_s::to_string(x);
}

//蒙哥马利空间
namespace Montgo{
    //256位无符号整数 由拼接两个128位整数实现
    struct u256 {
        u128 lo, hi;
        constexpr u256() : lo(), hi() {}
        constexpr u256(u128 _lo, u128 _hi) : lo(_lo), hi(_hi) {}
        //将两个128位数相乘并得到一个256位数
        static constexpr u256 mul128(u128 a, u128 b) {
            u64 a_hi(a >> 64), a_lo(a);
            u64 b_hi(b >> 64), b_lo(b);
            u128 p01(u128(a_lo) * b_lo), p12(u128(a_hi) * b_lo + u64(p01 >> 64));
            u64 t_hi(p12 >> 64), t_lo(p12);
            u128 p23(u128(a_hi) * b_hi + u64((p12 = u128(a_lo) * b_hi + t_lo) >> 64) + t_hi);
            return u256(u64(p01) | (p12 << 64), p23);
        }
    };
    //128位蒙哥马利约减器
    struct Mont128 {
        //模数 (1+k'N) R = 2 ^ 128 % Mod -> R2 = 2 ^ 256 % Mod, 
        u128 Mod, Inv, R2;

        //模数的高64位 模数的低64位 R3 = 2 ^ 384 % Mod
        u128 Mod_hi, Mod_lo, R3;

        constexpr Mont128(u128 n) : Mod(n), Inv(n), R2((-n % n) << 1), Mod_hi(n >> 64), Mod_lo(u64(n)), R3(){
            //牛顿迭代求Inv
            for (int i = 0; i < 6; ++i){
                Inv *= 2 - n * Inv;
            }
            for (int i = 0; i < 7; ++i){
                R2 = mul_strict(R2, R2);
            }
            R3 = mul_strict(R2, R2);
        }
        //蒙哥马利约减 返回值将在[0,2 * Mod)之间
        constexpr u128 reduce(u256 x) const {
            u128 o(x.lo * Inv);
            u64 o_hi(o >> 64), o_lo(o);
            return x.hi - (o_hi * Mod_hi + (((o_lo * Mod_hi) + (o_hi * Mod_lo) + u64((o_lo * Mod_lo) >> 64)) >> 64)) + Mod;
        }
        //进入蒙哥马利数域 *= R
        constexpr u128 In(u128 n) const {
            return reduce(u256::mul128(n, R2));
        }
        //进入并进入蒙哥马利数域 *= R *= R
        constexpr u128 In_In(u128 n) const {
            return reduce(u256::mul128(n, R3));
        }
        //蒙哥马利约乘 返回值将在[0,2 * Mod) 之间
        constexpr u128 mul(u128 a, u128 b) const {
            return reduce(u256::mul128(a, b));
        }
        //严格的蒙哥马利约减 返回值将在[0,Mod) 之间
        constexpr u128 reduce_strict(u256 x) const {
            u128 o(x.lo * Inv);
            u64 o_hi(o >> 64), o_lo(o);
            o = x.hi - (o_hi * Mod_hi + (((o_lo * Mod_hi) + (o_hi * Mod_lo) + u64((o_lo * Mod_lo) >> 64)) >> 64));
            return i128(o) < 0 ? o + Mod : o;
        }
        //严格的蒙哥马利约乘 返回值将在[0,Mod) 之间
        constexpr u128 mul_strict(u128 a, u128 b) const {
            return reduce_strict(u256::mul128(a, b));
        }
        //离开蒙哥马利数域 /= R 
        constexpr u128 Out(u128 x) const {
            return reduce_strict(u256(x, 0));
        }
        //获取蒙哥马利数域下的逆元
        constexpr u128 inv(u128 t) const {
            return mul(u128_s::inv128(t, Mod), R3);
        }
    };
}

//定义了Z作为域(交换除环).并定义在其上的基本运算.
//Z在蒙哥马利模空间下且值域[0, mod * 2)
namespace field_Z{
    //使用的模数 另外注意无法直接表示如此巨大的数字 通过stu128转换了一下
    constexpr u128 mod(u128_s::stu128("21267647932558654224715329996419235841"));
    constexpr u128 mod2(mod * 2);
    constexpr Montgo::Mont128 mont(mod);
    template<u128 M>constexpr u128 shrink(u128 x){return x >= M ? x - M : x;}
    template<u128 M>constexpr u128 dilate(u128 x){return i128(x) < 0 ? x + M : x;}
    //Z类型是一个抽象出来的概念 实际上就是无符号128位整数
    using Z = u128;
    //蒙哥马利数域下的单位元
    constexpr Z one(shrink<mod>(mont.In(1)));

    constexpr Z InZ(u128 x) {
        return mont.In(x);
    }
    constexpr Z In_InZ(u128 x) {
        return mont.In_In(x);
    }
    constexpr u128 OutZ(Z x) {
        return mont.Out(x);
    }
    constexpr Z addZ(Z a, Z b) {
        return shrink<mod2>(a + b);
    }
    constexpr Z subZ(Z a, Z b) {
        return dilate<mod2>(a - b);
    }
    constexpr Z mulZ(Z a, Z b) {
        return mont.mul(a, b);
    }
    constexpr Z invZ(Z t) {
        return mont.inv(t);
    }
    constexpr Z divZ(Z a, Z b) {
        return mulZ(a, invZ(b));
    }
    constexpr Z powZ(Z a, u128 b) {
        Z r(one);
        for(; b; b >>= 1, a = mulZ(a, a)){
            if(b & 1){
                r = mulZ(r, a);
            }
        }
        return r;
    }

    constexpr Z mulZ_strict(Z a, Z b) {
        return mont.mul_strict(a, b);
    }
}

//多项式主体
namespace poly{
    //多项式主体::引入对多项式的基础支持
    namespace poly_base{
        //多项式基础支持::引入所处的域——Z
        using namespace field_Z;
        //按位向上取整
        inline constexpr int bit_ceil(int x){
            return 1 << (std::__lg(x - 1) + 1);
        }
        //多项式基础支持::引入对NTT的支持
        namespace poly_NTT_helper{
            //mod = 2^65 * 3^1 * 5^1 * 23^1 * 101^1 * 386719^1 * 42779309^1 + 1
            constexpr int mp2(65);
            //原根为11
            constexpr Z _g(InZ(11));
            struct P_R_Tab{
                Z t[mp2 + 1];
                constexpr P_R_Tab(Z G):t(){
                    t[mp2] = powZ(G, (mod - 1) >> mp2);
                    for(int i = mp2 - 1; i; --i){
                        t[i] = mulZ(t[i+1], t[i+1]);
                    }
                }
                Z operator [] (int i) const {
                    return t[i];
                }
            };
            constexpr P_R_Tab __g(_g),__g_Inv(invZ(_g));
            int size_W(-1);
            Z *Wn(nullptr), *Wn_Inv(nullptr);
            void ntt_init_(int lim){
                if(lim > size_W){
                    if(Wn != nullptr){
                        delete[] Wn;
                    }
                    else{
                        lim = std::max(2, lim);
                    }
                    size_W = lim, Wn = new Z[2 * lim], Wn_Inv = Wn + lim;
                    Wn[0] = Wn[1] = Wn_Inv[0] = Wn_Inv[1] = one;
                    for(int i = 2, R = 2, i2 = 4; i < lim; i <<= 1, ++R, i2 <<= 1){
                        Z g_w(__g[R]), g_w_Inv(__g_Inv[R]);
                        for(int k = i; k < i2; k += 2){
                            Wn[k] = Wn[k >> 1], Wn[k + 1] = mulZ(Wn[k], g_w);
                            Wn_Inv[k] = Wn_Inv[k >> 1], Wn_Inv[k + 1] = mulZ(Wn_Inv[k], g_w_Inv);
                        }
                    }
                }
            }
        }using namespace poly_NTT_helper;

    }using namespace poly_base;

}

namespace poly{
    //多项式主体::引入基于转置原理的(DIF式)NTT和(DIT式)INTT
    namespace poly_NTT{
        //快速数论变换 (DIF)
        void NTT(Z* A, int lim){
            ntt_init_(lim);
            for(int i(lim >> 1), R(lim); i; i >>= 1, R >>= 1){
                Z *wn(Wn + i), *a(A + i);
                for(int j = 0; j < lim; j += R){
                    for(int k = 0; k < i; ++k){
                        Z x(A[j + k]), y(a[j + k]);
                        a[j + k] = mulZ(x - y + mod2, wn[k]), A[j + k] = addZ(x, y);
                    }
                }
            }
        }
        //快速数论变换.逆 (DIT)
        void INTT(Z* A, int lim){
            ntt_init_(lim);
            for(int i(1), R(2); i < lim; i <<= 1, R <<= 1){
                Z *wn(Wn_Inv + i), *a(A + i);
                for(int j = 0; j < lim; j += R){
                    for(int k = 0; k < i; ++k){
                        Z x(shrink<mod2>(A[j + k])), y(mulZ(a[j + k], wn[k]));
                        a[j + k] = x - y + mod2, A[j + k] = x + y;
                    }
                }
            }
            Z invt(In_InZ(mod - ((mod - 1) >> std::__lg(lim))));
            for(int i = 0; i < lim; ++i){
                A[i] = mulZ_strict(A[i], invt);
            }
        }
    }using namespace poly_NTT;

    //点乘
    void dot(Z* A, int n, Z* B){
        for(int i = 0; i < n; ++i){
            A[i] = mulZ(A[i],B[i]);
        }
    }

    //卷积
    void Conv(Z* A, int lim, Z* B){
        NTT(A, lim), NTT(B, lim), dot(A, lim, B), INTT(A, lim);
    }

    //自动卷积
    void autoConv(Z* A, int n, Z* B, int m){
        int lim(bit_ceil(n + m + 1));
        std::fill(A + n + 1, A + lim, 0), std::fill(B + m + 1, B + lim, 0), Conv(A,lim,B);
    }
}

constexpr int maxn = 1 << 21 | 5;

poly::Z A[maxn], B[maxn];

int main(){
    std::ios::sync_with_stdio(false), std::cin.tie(nullptr);
    int n, m;
    std::cin >> n >> m;
    for(int i = 0; i <= n; ++i){
        std::cin >> A[i];
    }
    for(int i = 0; i <= m; ++i){
        std::cin >> B[i]; 
    }
    poly::autoConv(A, n, B, m);
    for(int i = 0; i <= n + m; ++i){
        std::cout << A[i] << ' ';
    }
    return 0;
}

upd:给一个好一点的实现,不过需要C++20(https://www.luogu.com.cn/paste/f6v37gub)

upd:给一个可以用来验证正确性的题目(https://www.luogu.com.cn/problem/U291234)