高精度多项式乘法的另一种实现
Pulsating_Dust · · 个人记录
传统的高精度多项式乘法一般使用三模NTT或者拆系数FFT实现。
三模NTT的精度为
给出一种大力出奇迹的高精度多项式乘法,其精度能达到
同时速度没有太慢。(任意模数多项式乘法 cin,cout输入输出 900ms)
我们考虑NTT的值域受到模数的限制,因此我们可以考虑使用一个巨大的 NTT 模数,如果使用一个 int128 范围内的NTT模数即可解决问题。
对 int128 的模乘可以通过拆位乘和蒙哥马利约减实现,但是我没有在网上找到 int128 范围内的 NTT 模数,这里给出两个。
不难发现
手玩即可发现它的最小原根是
不难发现
手玩即可发现它的最小原根是
剩下的就是抄一下 NTT 板子,修改一下即可。
给出一个十分糟糕的实现。
#include<iostream>
using i64 = long long;
using u64 = unsigned long long;
using i128 = __int128;
using u128 = __uint128_t;
//对128位数进行的基础支持
namespace u128_s{
constexpr u128 stu128(const char* s){
u128 x = 0;
while (*s) {x = x * 10 + (*s++ - '0');}
return x;
}
u128 stu128(const std::string &s){
return stu128(s.data());
}
std::string to_string(u128 x){
char bbuf[40], *p = bbuf + 40;
do {
*--p = (x % 10) ^ 48, x /= 10;
} while (x > 0);
return std::string(p, bbuf + 40);
}
constexpr void exgcd(i128 a, i128 b, i128 &x, i128 &y) {
if (b == 0){
return ;
}
exgcd(b, a % b, y, x), y -= a / b * x;
}
constexpr u128 inv128(u128 t,u128 m) {
i128 q(m), x(1), y(0);
return exgcd(t, q, x, y), (x %= i128(m)) < 0 ? x + m : x;
}
};
std::istream& operator >> (std::istream& IN, u128 &x){
std::string s;
return IN >> s, x = u128_s::stu128(s), IN;
}
std::ostream& operator << (std::ostream& OUT, u128 x){
return OUT << u128_s::to_string(x);
}
//蒙哥马利空间
namespace Montgo{
//256位无符号整数 由拼接两个128位整数实现
struct u256 {
u128 lo, hi;
constexpr u256() : lo(), hi() {}
constexpr u256(u128 _lo, u128 _hi) : lo(_lo), hi(_hi) {}
//将两个128位数相乘并得到一个256位数
static constexpr u256 mul128(u128 a, u128 b) {
u64 a_hi(a >> 64), a_lo(a);
u64 b_hi(b >> 64), b_lo(b);
u128 p01(u128(a_lo) * b_lo), p12(u128(a_hi) * b_lo + u64(p01 >> 64));
u64 t_hi(p12 >> 64), t_lo(p12);
u128 p23(u128(a_hi) * b_hi + u64((p12 = u128(a_lo) * b_hi + t_lo) >> 64) + t_hi);
return u256(u64(p01) | (p12 << 64), p23);
}
};
//128位蒙哥马利约减器
struct Mont128 {
//模数 (1+k'N) R = 2 ^ 128 % Mod -> R2 = 2 ^ 256 % Mod,
u128 Mod, Inv, R2;
//模数的高64位 模数的低64位 R3 = 2 ^ 384 % Mod
u128 Mod_hi, Mod_lo, R3;
constexpr Mont128(u128 n) : Mod(n), Inv(n), R2((-n % n) << 1), Mod_hi(n >> 64), Mod_lo(u64(n)), R3(){
//牛顿迭代求Inv
for (int i = 0; i < 6; ++i){
Inv *= 2 - n * Inv;
}
for (int i = 0; i < 7; ++i){
R2 = mul_strict(R2, R2);
}
R3 = mul_strict(R2, R2);
}
//蒙哥马利约减 返回值将在[0,2 * Mod)之间
constexpr u128 reduce(u256 x) const {
u128 o(x.lo * Inv);
u64 o_hi(o >> 64), o_lo(o);
return x.hi - (o_hi * Mod_hi + (((o_lo * Mod_hi) + (o_hi * Mod_lo) + u64((o_lo * Mod_lo) >> 64)) >> 64)) + Mod;
}
//进入蒙哥马利数域 *= R
constexpr u128 In(u128 n) const {
return reduce(u256::mul128(n, R2));
}
//进入并进入蒙哥马利数域 *= R *= R
constexpr u128 In_In(u128 n) const {
return reduce(u256::mul128(n, R3));
}
//蒙哥马利约乘 返回值将在[0,2 * Mod) 之间
constexpr u128 mul(u128 a, u128 b) const {
return reduce(u256::mul128(a, b));
}
//严格的蒙哥马利约减 返回值将在[0,Mod) 之间
constexpr u128 reduce_strict(u256 x) const {
u128 o(x.lo * Inv);
u64 o_hi(o >> 64), o_lo(o);
o = x.hi - (o_hi * Mod_hi + (((o_lo * Mod_hi) + (o_hi * Mod_lo) + u64((o_lo * Mod_lo) >> 64)) >> 64));
return i128(o) < 0 ? o + Mod : o;
}
//严格的蒙哥马利约乘 返回值将在[0,Mod) 之间
constexpr u128 mul_strict(u128 a, u128 b) const {
return reduce_strict(u256::mul128(a, b));
}
//离开蒙哥马利数域 /= R
constexpr u128 Out(u128 x) const {
return reduce_strict(u256(x, 0));
}
//获取蒙哥马利数域下的逆元
constexpr u128 inv(u128 t) const {
return mul(u128_s::inv128(t, Mod), R3);
}
};
}
//定义了Z作为域(交换除环).并定义在其上的基本运算.
//Z在蒙哥马利模空间下且值域[0, mod * 2)
namespace field_Z{
//使用的模数 另外注意无法直接表示如此巨大的数字 通过stu128转换了一下
constexpr u128 mod(u128_s::stu128("21267647932558654224715329996419235841"));
constexpr u128 mod2(mod * 2);
constexpr Montgo::Mont128 mont(mod);
template<u128 M>constexpr u128 shrink(u128 x){return x >= M ? x - M : x;}
template<u128 M>constexpr u128 dilate(u128 x){return i128(x) < 0 ? x + M : x;}
//Z类型是一个抽象出来的概念 实际上就是无符号128位整数
using Z = u128;
//蒙哥马利数域下的单位元
constexpr Z one(shrink<mod>(mont.In(1)));
constexpr Z InZ(u128 x) {
return mont.In(x);
}
constexpr Z In_InZ(u128 x) {
return mont.In_In(x);
}
constexpr u128 OutZ(Z x) {
return mont.Out(x);
}
constexpr Z addZ(Z a, Z b) {
return shrink<mod2>(a + b);
}
constexpr Z subZ(Z a, Z b) {
return dilate<mod2>(a - b);
}
constexpr Z mulZ(Z a, Z b) {
return mont.mul(a, b);
}
constexpr Z invZ(Z t) {
return mont.inv(t);
}
constexpr Z divZ(Z a, Z b) {
return mulZ(a, invZ(b));
}
constexpr Z powZ(Z a, u128 b) {
Z r(one);
for(; b; b >>= 1, a = mulZ(a, a)){
if(b & 1){
r = mulZ(r, a);
}
}
return r;
}
constexpr Z mulZ_strict(Z a, Z b) {
return mont.mul_strict(a, b);
}
}
//多项式主体
namespace poly{
//多项式主体::引入对多项式的基础支持
namespace poly_base{
//多项式基础支持::引入所处的域——Z
using namespace field_Z;
//按位向上取整
inline constexpr int bit_ceil(int x){
return 1 << (std::__lg(x - 1) + 1);
}
//多项式基础支持::引入对NTT的支持
namespace poly_NTT_helper{
//mod = 2^65 * 3^1 * 5^1 * 23^1 * 101^1 * 386719^1 * 42779309^1 + 1
constexpr int mp2(65);
//原根为11
constexpr Z _g(InZ(11));
struct P_R_Tab{
Z t[mp2 + 1];
constexpr P_R_Tab(Z G):t(){
t[mp2] = powZ(G, (mod - 1) >> mp2);
for(int i = mp2 - 1; i; --i){
t[i] = mulZ(t[i+1], t[i+1]);
}
}
Z operator [] (int i) const {
return t[i];
}
};
constexpr P_R_Tab __g(_g),__g_Inv(invZ(_g));
int size_W(-1);
Z *Wn(nullptr), *Wn_Inv(nullptr);
void ntt_init_(int lim){
if(lim > size_W){
if(Wn != nullptr){
delete[] Wn;
}
else{
lim = std::max(2, lim);
}
size_W = lim, Wn = new Z[2 * lim], Wn_Inv = Wn + lim;
Wn[0] = Wn[1] = Wn_Inv[0] = Wn_Inv[1] = one;
for(int i = 2, R = 2, i2 = 4; i < lim; i <<= 1, ++R, i2 <<= 1){
Z g_w(__g[R]), g_w_Inv(__g_Inv[R]);
for(int k = i; k < i2; k += 2){
Wn[k] = Wn[k >> 1], Wn[k + 1] = mulZ(Wn[k], g_w);
Wn_Inv[k] = Wn_Inv[k >> 1], Wn_Inv[k + 1] = mulZ(Wn_Inv[k], g_w_Inv);
}
}
}
}
}using namespace poly_NTT_helper;
}using namespace poly_base;
}
namespace poly{
//多项式主体::引入基于转置原理的(DIF式)NTT和(DIT式)INTT
namespace poly_NTT{
//快速数论变换 (DIF)
void NTT(Z* A, int lim){
ntt_init_(lim);
for(int i(lim >> 1), R(lim); i; i >>= 1, R >>= 1){
Z *wn(Wn + i), *a(A + i);
for(int j = 0; j < lim; j += R){
for(int k = 0; k < i; ++k){
Z x(A[j + k]), y(a[j + k]);
a[j + k] = mulZ(x - y + mod2, wn[k]), A[j + k] = addZ(x, y);
}
}
}
}
//快速数论变换.逆 (DIT)
void INTT(Z* A, int lim){
ntt_init_(lim);
for(int i(1), R(2); i < lim; i <<= 1, R <<= 1){
Z *wn(Wn_Inv + i), *a(A + i);
for(int j = 0; j < lim; j += R){
for(int k = 0; k < i; ++k){
Z x(shrink<mod2>(A[j + k])), y(mulZ(a[j + k], wn[k]));
a[j + k] = x - y + mod2, A[j + k] = x + y;
}
}
}
Z invt(In_InZ(mod - ((mod - 1) >> std::__lg(lim))));
for(int i = 0; i < lim; ++i){
A[i] = mulZ_strict(A[i], invt);
}
}
}using namespace poly_NTT;
//点乘
void dot(Z* A, int n, Z* B){
for(int i = 0; i < n; ++i){
A[i] = mulZ(A[i],B[i]);
}
}
//卷积
void Conv(Z* A, int lim, Z* B){
NTT(A, lim), NTT(B, lim), dot(A, lim, B), INTT(A, lim);
}
//自动卷积
void autoConv(Z* A, int n, Z* B, int m){
int lim(bit_ceil(n + m + 1));
std::fill(A + n + 1, A + lim, 0), std::fill(B + m + 1, B + lim, 0), Conv(A,lim,B);
}
}
constexpr int maxn = 1 << 21 | 5;
poly::Z A[maxn], B[maxn];
int main(){
std::ios::sync_with_stdio(false), std::cin.tie(nullptr);
int n, m;
std::cin >> n >> m;
for(int i = 0; i <= n; ++i){
std::cin >> A[i];
}
for(int i = 0; i <= m; ++i){
std::cin >> B[i];
}
poly::autoConv(A, n, B, m);
for(int i = 0; i <= n + m; ++i){
std::cout << A[i] << ' ';
}
return 0;
}
upd:给一个好一点的实现,不过需要C++20(https://www.luogu.com.cn/paste/f6v37gub)
upd:给一个可以用来验证正确性的题目(https://www.luogu.com.cn/problem/U291234)