压位高精度模板(带FFT)

· · 个人记录

\rule{120pt}{30pt}\kern{-85pt}\color{white}\raisebox{12pt}{\sf 模板集合}

更新

2023/12/31 更新:

2024/01/01 更新:

2024/02/05 更新:

2024/02/23 更新:

保证

已通过所有高精度模板题(包括 FFT 板子)。

与 python 在 $2 \times 10^5$ 的数据下进行过对拍,结果正确;除取余外,性能遥遥领先于 python。

详情见 性能测试 部分。

实现细节

对于高精度运算 op(a, b),以下均用 n 代表 a 的位数,m 代表 b 的位数。

压位位数 w = 8,可修改。

由于 FFT 做压位的上限太低,且朴素乘法有 $\dfrac{1}{w^2}$ 的常数,故压位后的位数之积 $\dfrac{n}{w} \cdot \dfrac{m}{w}$ 在 $65536$ 以内的数据采用朴素高精乘法,超过则使用 FFT。除法使用二分。

复杂度:

操作 函数名 时间 空间
字符串初始化 / O( \dfrac{n}{w} ) O( \dfrac{n}{w} )
整型初始化 / O( \dfrac{\log n}{w} ) O( \dfrac{\log n}{w} )
复制初始化 / O( \dfrac{n}{w} ) O( \dfrac{n}{w} )
清空(置 0) clear O( \dfrac{n}{w} ) O(1)
预先分配空间 reserve O(k) O(k)
读入 / O(n) O( n)
输出 / O(n) O(n)
取绝对值 abs O( \dfrac{n}{w} ) O( \dfrac{n}{w} )
取反 / O( \dfrac{n}{w} ) O( \dfrac{n}{w} )
判断是否为 0 iszero O(1) O(1)
乘 2 mul2 O( \dfrac{n}{w} ) O( \dfrac{n}{w} )
除以 2 div2 O( \dfrac{n}{w} ) O( \dfrac{n}{w} )
比较运算符 / O( \dfrac{n}{w} ) O(1)
加法 / O( \dfrac{n+m}{w} ) O( \dfrac{\max (n, m)}{w} )
减法 / O( \dfrac{n+m}{w} ) O( \dfrac{\max (n, m)}{w} )
自增 / O( \dfrac{n}{w} ) O( \dfrac{n}{w} )
自减 / O( \dfrac{n}{w} ) O( \dfrac{n}{w} )
乘高精(朴素) / O( \dfrac{nm}{w^2} ) O( \dfrac{n+m}{w} )
乘高精(FFT) / O((n+m)\log(n+m)) O(n+m)
乘单精 / O(\dfrac{n}{w}) O(\dfrac{n}{w})
除高精(朴素) / O(\dfrac{n^3+m^3}{w^2}) O( \dfrac{n+m}{w} )
除高精(FFT) / O((n^2+m^2) \log (n+m)) O( n+m )
除单精 / O(\dfrac{n}{w}) O(\dfrac{n}{w})
取余(朴素) / O(\dfrac{n^3+m^3}{w^2}) O( \dfrac{n+m}{w} )
取余(FFT) / O((n^2+m^2) \log (n+m)) O( n+m )
乘方(朴素) pow O( \dfrac{nm}{w^2} \log m) O( \dfrac{n+m}{w} )
乘方(FFT) pow O( (n+m)\log(n+m)) O(n+m)
进制乘方 pow_base O(n \log n) O(n)
模 2 mod2 O(1) O(1)
平方根(朴素) sqrt O(\dfrac{n^3}{w} \log {n}) O(\dfrac{n}{w})
平方根(FFT) sqrt O(n^2 \log {n}) O(n)

性能测试

统计图(1000 次测试取平均值):

附上性能测试代码:

//对拍&测试
#include <bits/stdc++.h>
#include <windows.h>
using namespace std;

// Non-deterministic seed source and the engine that drives all test-data generation.
random_device rd;
mt19937 ran{rd()};

// Exclusive upper bound on the number of digits drawn for each generated operand.
const int DATA = 20000;

// Writes two random positive decimal integers to stdout, separated by one space.
// Both operands share the same length max(len, 1), where len = ran() % DATA,
// and neither has a leading zero.
void gen() {
    const int len = ran() % DATA;
    for (int operand = 0; operand < 2; ++operand) {
        if (operand != 0) putchar(' ');
        putchar('1' + ran() % 9);                           // leading digit 1..9
        for (int pos = 1; pos < len; ++pos) putchar('0' + ran() % 10);
    }
}

// Timestamps, in clock() ticks, captured before and after a timed command.
double aaa;
double bbb;
// _begin / _end record the two timestamps; _time is the elapsed ticks between them.
#define _begin (aaa = clock())
#define _end (bbb = clock())
#define _time (bbb - aaa)

// Differential test + benchmark: runs the Python reference and the BigInteger
// executable on 1000 random inputs, compares their outputs with `fc`, and
// appends one CSV row per run (state, Python ms, BigInteger ms) to stderr,
// which is redirected to the result file.
int main() {
    // Wall-clock milliseconds spent running `cmd`. std::clock() is not used
    // because on POSIX it reports this process's CPU time, which excludes the
    // child processes started by system().
    auto timed = [](const char* cmd) {
        const auto start = chrono::steady_clock::now();
        system(cmd);
        const auto stop = chrono::steady_clock::now();
        return chrono::duration<double, milli>(stop - start).count();
    };

    freopen("result/result-2e4-div.csv", "w", stderr);
    cerr << "State,Python-time,BigInteger-time" << endl;
    int T = 1000;
    while (T--) {
        // Write a fresh input; reopening stdout afterwards flushes and closes data.in.
        freopen("data.in", "w", stdout);
        gen();
        freopen("details.out", "w", stdout);

        const double py = timed("python std.py");
        const double big = timed("BigInteger.exe < data.in > biginteger.out");

        // fc exits non-zero when the two outputs differ.
        if (system("fc py.out biginteger.out")) {
            cerr << "WA,0.0,0.0" << endl;
            break;
        }

        cerr << "AC," << py << "," << big << endl;
    }
    return 0;
}
# Plotting: average Python vs. BigInteger run time per data size for one operation.
import csv

import numpy as np

NAME = ['add', 'sub', 'div', 'mul', 'mod']
FILE = 'result-{data}-{name}.csv'
DATA = ['2e3', '2e4', '2e5']
# Chinese operation names used in the chart title.
OP_LABEL = {'add': '加法', 'sub': '减法', 'div': '除法', 'mul': '乘法', 'mod': '取余'}
# Operation to plot (one of NAME).
K = 'mod'


def load_average(path):
    """Average the timings of the accepted runs in one harness CSV.

    The harness writes a ``State,Python-time,BigInteger-time`` header followed
    by one row per run. Rows whose state is not ``AC`` (for example the ``WA``
    row written before the harness stops) are ignored.

    Returns:
        ``(mean Python time, mean BigInteger time)`` as floats.

    Raises:
        ValueError: if the file contains no ``AC`` rows.
    """
    py_times, big_times = [], []
    with open(path, 'r', encoding='utf-8', newline='') as fp:
        for row in csv.DictReader(fp):
            if row.get('State') != 'AC':
                continue
            py_times.append(float(row['Python-time']))
            big_times.append(float(row['BigInteger-time']))
    if not py_times:
        raise ValueError(f'{path}: no AC rows to average')
    return sum(py_times) / len(py_times), sum(big_times) / len(big_times)


def main():
    # Imported here so the data helper above is usable without matplotlib installed.
    import matplotlib.pyplot as plt

    plt.rcParams['font.sans-serif'] = ['SimHei']

    # Only the files of the operation being plotted are needed.
    averages = [load_average(FILE.format(data=d, name=K)) for d in DATA]
    py = np.array([avg[0] for avg in averages])
    big = np.array([avg[1] for avg in averages])
    print(py, big)

    index = np.arange(len(DATA))
    bar_width = 0.35

    plt.bar(index, big, bar_width, color='b', label='BigInteger')
    plt.bar(index + bar_width, py, bar_width, color='g', label='Python')

    plt.title(f'高精度{OP_LABEL[K]}の性能统计图')
    plt.xlabel('数据范围')
    plt.ylabel('运行时间(ms)')
    # Centre each tick label under its pair of bars.
    plt.xticks(index + bar_width / 2, DATA)

    plt.legend(loc='upper left')

    plt.show()


if __name__ == '__main__':
    main()

以及高精板子题的测试结果作为参考(测试点平均运行时间/最大运行时间):

题目 平均运行时间/ms 最大运行时间/ms
P1601 (加法) 3.4 4
P2142 (减法) 3.6 4
P1303 (朴素乘法) 3.0 3
P1919 (FFT) 915.0 948
P1480 (高精除单精) 3.0 3
P2005 (高精除高精) 194.5 727
P1932 (全套) 75.7 481

大模板

400 行,写了半个月。

#ifndef BIGINTEGER_H
#define BIGINTEGER_H

#include <climits>
#include <cmath>
#include <cstddef>
#include <iomanip>
#include <istream>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#ifndef BIGINTEGER_DISABLE_FFTMUL

#include <cmath>

// Floating-point FFT helpers used by BigInteger for large multiplications.
// Every function is `inline` and there is no namespace-scope mutable state, so
// this header can be included from several translation units without
// multiple-definition link errors.
namespace BigInteger_FFT {
    const double PI = std::acos(-1.0);

    // Minimal complex number with the operations the FFT needs.
    class complex {
    public:
        double real, imag;

        complex(double x = 0.0, double y = 0.0) : real(x), imag(y) {}

        complex operator + (const complex& other) const {
            return complex(real + other.real, imag + other.imag);
        }
        complex operator - (const complex& other) const {
            return complex(real - other.real, imag - other.imag);
        }
        complex operator * (const complex& other) const {
            return complex(real * other.real - imag * other.imag, real * other.imag + other.real * imag);
        }
    };

    /// In-place iterative radix-2 FFT over a[0, len).
    /// @param a    coefficients; must hold at least `len` elements
    /// @param len  transform length, a power of two
    /// @param type +1.0 for the forward transform, -1.0 for the inverse (unscaled)
    inline void fft(std::vector<complex>& a, int len, double type) {
        // Bit-reversal permutation so the butterflies can run bottom-up.
        for (int i = 1, j = 0; i < len; i++) {
            int bit = len >> 1;
            for (; j & bit; bit >>= 1) j ^= bit;
            j ^= bit;
            if (i < j) std::swap(a[i], a[j]);
        }
        std::vector<complex> roots(len > 1 ? len >> 1 : 1);
        for (int half = 1; half < len; half <<= 1) {
            // Twiddles are computed directly, not by repeated multiplication,
            // so rounding error does not accumulate along a stage.
            for (int k = 0; k < half; k++) {
                const double angle = PI * k / half;
                roots[k] = complex(std::cos(angle), std::sin(angle) * type);
            }
            for (int start = 0; start < len; start += half << 1) {
                for (int k = 0; k < half; k++) {
                    const complex u = a[start + k];
                    const complex v = a[start + k + half] * roots[k];
                    a[start + k] = u + v;
                    a[start + k + half] = u - v;
                }
            }
        }
    }

    /// Multiplies two polynomials with small non-negative integer coefficients.
    /// @return the a.size() + b.size() - 1 coefficients of a*b (empty if either input is empty)
    inline std::vector<int> mul(const std::vector<int>& a, const std::vector<int>& b) {
        if (a.empty() || b.empty()) return {};
        const int n = static_cast<int>(a.size()) - 1, m = static_cast<int>(b.size()) - 1;
        int len = 1;
        while (len <= n + m) len <<= 1;
        std::vector<complex> f(len), g(len);
        for (int i = 0; i <= n; i++) f[i] = complex(a[i], 0.0);
        for (int i = 0; i <= m; i++) g[i] = complex(b[i], 0.0);

        fft(f, len, 1.0), fft(g, len, 1.0);
        for (int i = 0; i < len; i++) f[i] = f[i] * g[i];
        fft(f, len, -1.0);

        std::vector<int> res;
        res.reserve(n + m + 1);
        // Scale the inverse transform and round to the nearest integer.
        for (int i = 0; i <= n + m; i++) res.push_back(static_cast<int>(f[i].real / len + 0.5));
        return res;
    }
} // namespace BigInteger_FFT

#endif // BIGINTEGER_DISABLE_FFTMUL

/// Arbitrary-precision signed integer stored as base-10^WIDTH limbs.
///
/// Representation invariants (restored by trim() after every operation):
///  - `digits` holds |value|, least significant limb first, with no high zero limbs;
///  - zero is an empty `digits` with `sign == true` (there is no negative zero).
/// Division and remainder truncate toward zero (the remainder takes the sign of
/// the dividend); dividing by zero throws -1.
class BigInteger {
private:
    using digit_t = long long;

    // WIDTH decimal digits per limb. BASE must equal 10^WIDTH, and BASE^2 must fit in digit_t.
    static constexpr int WIDTH = 8;
    static constexpr digit_t BASE = 100000000;
    // Use FFT multiplication once (limbs of a) * (limbs of b) reaches this value.
    static constexpr long long FFT_LIMIT = 65536;

    std::vector<digit_t> digits;
    bool sign = true;

    /// Drops high zero limbs and makes zero non-negative.
    void trim() {
        while (!digits.empty() && digits.back() == 0) digits.pop_back();
        if (digits.empty()) sign = true;
    }

    /// -1, 0 or +1 according to the sign of the value.
    int signum() const {return iszero() ? 0 : (sign ? 1 : -1);}

    /// Three-way comparison of |*this| and |x|: -1, 0 or +1.
    int compare_abs(const BigInteger& x) const {
        // Skip any high zero limbs so untrimmed values still compare correctly.
        std::size_t n = digits.size(), m = x.digits.size();
        while (n > 0 && digits[n - 1] == 0) n--;
        while (m > 0 && x.digits[m - 1] == 0) m--;
        if (n != m) return n < m ? -1 : 1;
        for (std::size_t i = n; i-- > 0;) {
            if (digits[i] != x.digits[i]) return digits[i] < x.digits[i] ? -1 : 1;
        }
        return 0;
    }

    /// Three-way comparison of the signed values: -1, 0 or +1.
    int compare(const BigInteger& x) const {
        const int sa = signum(), sb = x.signum();
        if (sa != sb) return sa < sb ? -1 : 1;
        if (sa == 0) return 0;
        const int mag = compare_abs(x);
        return sa > 0 ? mag : -mag;
    }

    /// |a| + |b| (non-negative result).
    static BigInteger add_abs(const BigInteger& a, const BigInteger& b) {
        const bool a_longer = a.digits.size() >= b.digits.size();
        const std::vector<digit_t>& x = a_longer ? a.digits : b.digits;
        const std::vector<digit_t>& y = a_longer ? b.digits : a.digits;
        BigInteger res;
        res.digits.reserve(x.size() + 1);
        digit_t carry = 0;
        for (std::size_t i = 0; i < x.size(); i++) {
            const digit_t cur = x[i] + carry + (i < y.size() ? y[i] : 0);
            carry = cur >= BASE ? 1 : 0;
            res.digits.push_back(carry ? cur - BASE : cur);
        }
        if (carry) res.digits.push_back(carry);
        res.trim();
        return res;
    }

    /// |a| - |b| (non-negative result); requires |a| >= |b|.
    static BigInteger sub_abs(const BigInteger& a, const BigInteger& b) {
        BigInteger res;
        res.digits.reserve(a.digits.size());
        digit_t borrow = 0;
        for (std::size_t i = 0; i < a.digits.size(); i++) {
            const digit_t cur = a.digits[i] - borrow - (i < b.digits.size() ? b.digits[i] : 0);
            borrow = cur < 0 ? 1 : 0;
            res.digits.push_back(borrow ? cur + BASE : cur);
        }
        res.trim();
        return res;
    }

    /// *this + (the value with magnitude |x| and sign `x_sign`); shared by + and -.
    BigInteger add_signed(const BigInteger& x, bool x_sign) const {
        BigInteger res;
        if (sign == x_sign) {
            res = add_abs(*this, x);
            res.sign = sign;
        } else {
            const int cmp = compare_abs(x);
            if (cmp == 0) return BigInteger();
            res = cmp > 0 ? sub_abs(*this, x) : sub_abs(x, *this);
            res.sign = cmp > 0 ? sign : x_sign;
        }
        res.trim();
        return res;
    }

    /// Approximates |v| / BASE^k from the three most significant limbs of v.
    static double leading(const BigInteger& v, std::size_t k) {
        double res = 0.0;
        const std::size_t sz = v.digits.size();
        for (std::size_t i = sz; i-- > 0 && i + 3 >= sz;) {
            res += static_cast<double>(v.digits[i])
                 * std::pow(static_cast<double>(BASE), static_cast<double>(i) - static_cast<double>(k));
        }
        return res;
    }

    /// Schoolbook long division of magnitudes: q = |a| / |b|, r = |a| % |b|.
    /// Each quotient limb is estimated from the leading limbs and then corrected
    /// exactly, so the work is about O(n*m) limb operations. Requires b != 0.
    static void divmod_abs(const BigInteger& a, const BigInteger& b, BigInteger& q, BigInteger& r) {
        const BigInteger divisor = b.abs();
        q = BigInteger();
        if (a.compare_abs(divisor) < 0) {r = a.abs(); return;}
        r = BigInteger();
        const std::size_t m = divisor.digits.size();
        const double lead_b = leading(divisor, m - 1);
        q.digits.assign(a.digits.size(), 0);
        for (std::size_t i = a.digits.size(); i-- > 0;) {
            // r = r * BASE + a.digits[i]
            r.digits.insert(r.digits.begin(), a.digits[i]);
            r.trim();
            if (r.compare_abs(divisor) < 0) continue;
            const double est = std::floor(leading(r, m - 1) / lead_b);
            digit_t d = est < 1.0 ? 1 : (est > BASE - 1 ? BASE - 1 : static_cast<digit_t>(est));
            // Correct the estimate so that d * divisor <= r < (d + 1) * divisor.
            BigInteger prod = divisor * static_cast<int>(d);
            while (prod.compare_abs(r) > 0) {prod = sub_abs(prod, divisor); d--;}
            r = sub_abs(r, prod);
            while (r.compare_abs(divisor) >= 0) {r = sub_abs(r, divisor); d++;}
            q.digits[i] = d;
        }
        q.trim();
        r.trim();
    }

#ifndef BIGINTEGER_DISABLE_FFTMUL
    /// Decimal digits of |x|, least significant first, without high zeros (empty for zero).
    static std::vector<int> decimal_digits(const BigInteger& x) {
        std::vector<int> res;
        res.reserve(x.digits.size() * WIDTH);
        for (digit_t limb : x.digits) {
            for (int k = 0; k < WIDTH; k++) {
                res.push_back(static_cast<int>(limb % 10));
                limb /= 10;
            }
        }
        while (!res.empty() && res.back() == 0) res.pop_back();
        return res;
    }

    /// Product computed with an FFT over base-10 digits of the magnitudes.
    BigInteger fft_mul(const BigInteger& other) const {
        const std::vector<int> prod = BigInteger_FFT::mul(decimal_digits(*this), decimal_digits(other));
        BigInteger res;
        res.digits.reserve(prod.size() / WIDTH + 2);
        // Propagate decimal carries and pack every WIDTH digits into one limb.
        long long carry = 0;
        digit_t limb = 0, scale = 1;
        int filled = 0;
        for (std::size_t i = 0; i < prod.size() || carry != 0; i++) {
            const long long cur = carry + (i < prod.size() ? prod[i] : 0);
            carry = cur / 10;
            limb += (cur % 10) * scale;
            scale *= 10;
            if (++filled == WIDTH) {
                res.digits.push_back(limb);
                limb = 0, scale = 1, filled = 0;
            }
        }
        if (filled != 0) res.digits.push_back(limb);
        res.sign = (sign == other.sign);
        res.trim();
        return res;
    }
#endif // BIGINTEGER_DISABLE_FFTMUL

public:
    BigInteger() = default;
    BigInteger(const std::string& s) {*this = s;}
    BigInteger(const long long& x) {*this = x;}
    BigInteger(const BigInteger& x) = default;
    BigInteger(BigInteger&& x) noexcept = default;

    /// Resets the value to zero.
    void clear() {sign = true, digits.clear();}
    /// Pre-allocates room for k limbs.
    void reserve(std::size_t k) {digits.reserve(k);}

    /// Parses an optionally signed ('-' or '+') decimal string; "" and "-" give zero.
    /// @throws std::invalid_argument if any other character is not a decimal digit
    BigInteger& operator = (const std::string& s) {
        digits.clear(), sign = true;
        std::size_t begin = 0;
        if (!s.empty() && (s[0] == '-' || s[0] == '+')) {
            sign = (s[0] != '-');
            begin = 1;
        }
        digits.reserve((s.size() - begin) / WIDTH + 1);
        // Consume WIDTH characters per limb, starting from the least significant end.
        for (std::size_t end = s.size(); end > begin;) {
            const std::size_t start = end - begin > static_cast<std::size_t>(WIDTH) ? end - WIDTH : begin;
            digit_t limb = 0;
            for (std::size_t k = start; k < end; k++) {
                if (s[k] < '0' || s[k] > '9') {
                    throw std::invalid_argument("BigInteger: invalid character in \"" + s + "\"");
                }
                limb = limb * 10 + (s[k] - '0');
            }
            digits.push_back(limb);
            end = start;
        }
        trim();
        return *this;
    }
    BigInteger& operator = (const long long& x) {
        digits.clear(), sign = (x >= 0);
        // Negate in unsigned arithmetic so LLONG_MIN needs no special case.
        unsigned long long n = x < 0 ? 0ULL - static_cast<unsigned long long>(x)
                                     : static_cast<unsigned long long>(x);
        const unsigned long long base = BASE;
        while (n != 0) {
            digits.push_back(static_cast<digit_t>(n % base));
            n /= base;
        }
        return *this;
    }
    BigInteger& operator = (const BigInteger& x) = default;
    BigInteger& operator = (BigInteger&& x) = default;
    ~BigInteger() = default;

    friend std::ostream& operator << (std::ostream& out, const BigInteger& x) {
        if (x.digits.empty()) {out << 0; return out;}
        if (!x.sign) out << '-';

        out << x.digits.back();
        // The fill character is sticky, so restore the caller's afterwards.
        const char old_fill = out.fill('0');
        for (std::size_t i = x.digits.size() - 1; i-- > 0;) out << std::setw(WIDTH) << x.digits[i];
        out.fill(old_fill);
        return out;
    }
    friend std::istream& operator >> (std::istream& in, BigInteger& x) {
        std::string s; in >> s; x = s; return in;
    }

    std::string to_str() const {
        std::stringstream ss;
        ss << *this;
        return ss.str();
    }
    /// Decimal digits of |*this|, most significant first ({0} for zero).
    std::vector<int> to_vint() const {
        if (iszero()) return {0};
        const std::string str = abs().to_str();
        std::vector<int> res;
        res.reserve(str.size());
        for (char ch : str) res.push_back(ch - '0');
        return res;
    }

    BigInteger abs() const {
        BigInteger res = *this; res.sign = true; return res;
    }
    BigInteger operator - () const {
        BigInteger res = *this;
        if (!res.iszero()) res.sign = !sign;   // never produce -0
        return res;
    }
    bool iszero() const {
        return digits.empty() || (digits.size() == 1 && digits.back() == 0);
    }

    /// *this * 2, keeping the sign.
    BigInteger mul2() const {
        BigInteger res;
        res.sign = sign;
        res.digits.reserve(digits.size() + 1);
        digit_t carry = 0;
        for (digit_t d : digits) {
            const digit_t cur = d * 2 + carry;
            carry = cur >= BASE ? 1 : 0;
            res.digits.push_back(carry ? cur - BASE : cur);
        }
        if (carry) res.digits.push_back(carry);
        res.trim();
        return res;
    }
    /// *this / 2, truncated toward zero.
    BigInteger div2() const {
        BigInteger res = *this;
        for (std::size_t i = res.digits.size(); i-- > 0;) {
            if (i > 0 && (res.digits[i] & 1)) res.digits[i - 1] += BASE;
            res.digits[i] >>= 1;
        }
        res.trim();
        return res;
    }

    bool operator < (const BigInteger& x) const {return compare(x) < 0;}
    bool operator > (const BigInteger& x) const {return compare(x) > 0;}
    bool operator == (const BigInteger& x) const {return compare(x) == 0;}
    bool operator != (const BigInteger& x) const {return compare(x) != 0;}
    bool operator <= (const BigInteger& x) const {return compare(x) <= 0;}
    bool operator >= (const BigInteger& x) const {return compare(x) >= 0;}

    BigInteger operator + (const BigInteger& x) const {return add_signed(x, x.sign);}
    BigInteger operator - (const BigInteger& x) const {return add_signed(x, !x.sign);}

    BigInteger& operator ++ () {return *this = *this + 1LL;}
    BigInteger& operator -- () {return *this = *this - 1LL;}
    BigInteger operator ++ (int) {BigInteger old = *this; ++*this; return old;}
    BigInteger operator -- (int) {BigInteger old = *this; --*this; return old;}

    /// Product; switches to FFT multiplication for large operands.
    BigInteger operator * (const BigInteger& x) const {
        if (iszero() || x.iszero()) return BigInteger();
        const std::size_t n = digits.size(), m = x.digits.size();
#ifndef BIGINTEGER_DISABLE_FFTMUL
        if (static_cast<long long>(n) * static_cast<long long>(m) >= FFT_LIMIT) return fft_mul(x);
#endif

        BigInteger res;
        res.sign = (sign == x.sign);
        res.digits.assign(n + m, 0);
        for (std::size_t i = 0; i < n; i++) {
            digit_t carry = 0;
            for (std::size_t j = 0; j < m; j++) {
                // Bounded by (BASE-1) + (BASE-1)^2 + (BASE-1) = BASE^2 - 1.
                const digit_t cur = res.digits[i + j] + digits[i] * x.digits[j] + carry;
                res.digits[i + j] = cur % BASE;
                carry = cur / BASE;
            }
            res.digits[i + m] += carry;
        }
        res.trim();
        return res;
    }

    /// Product with a machine integer (any int, including INT_MIN).
    BigInteger operator * (const int& x) const {
        if (iszero() || x == 0) return BigInteger();
        const long long c = x < 0 ? -static_cast<long long>(x) : static_cast<long long>(x);
        BigInteger res;
        res.sign = (sign == (x > 0));
        res.digits.reserve(digits.size() + 2);
        long long carry = 0;
        for (digit_t d : digits) {
            const long long cur = d * c + carry;   // < BASE * 2^31 + carry, fits in long long
            res.digits.push_back(cur % BASE);
            carry = cur / BASE;
        }
        while (carry != 0) {
            res.digits.push_back(carry % BASE);
            carry /= BASE;
        }
        res.trim();
        return res;
    }

    /// Quotient by a machine integer, truncated toward zero.
    /// @throws -1 if x == 0
    BigInteger operator / (const long long& x) const {
        if (x == 0) throw -1;
        // Above this magnitude rem * BASE could overflow; use big division instead.
        const long long limit = LLONG_MAX / BASE;
        if (x > limit || x < -limit) return *this / BigInteger(x);
        if (iszero()) return BigInteger();
        const long long divisor = x < 0 ? -x : x;
        BigInteger res;
        res.sign = (sign == (x > 0));
        res.digits.resize(digits.size());
        long long rem = 0;
        for (std::size_t i = digits.size(); i-- > 0;) {
            rem = rem * BASE + digits[i];
            res.digits[i] = rem / divisor;
            rem %= divisor;
        }
        res.trim();
        return res;
    }

    /// Quotient truncated toward zero. @throws -1 if x == 0
    BigInteger operator / (const BigInteger& x) const {
        if (x.iszero()) throw -1;
        BigInteger q, r;
        divmod_abs(*this, x, q, r);
        q.sign = (sign == x.sign);
        q.trim();
        return q;
    }
    /// Remainder with the sign of the dividend. @throws -1 if x == 0
    BigInteger operator % (const BigInteger& x) const {
        if (x.iszero()) throw -1;
        BigInteger q, r;
        divmod_abs(*this, x, q, r);
        r.sign = sign;
        r.trim();
        return r;
    }

    /// *this raised to |x| by binary exponentiation.
    BigInteger pow(const BigInteger& x) const {
        BigInteger res(1);
        BigInteger a = *this, b = x;
        for (; !b.iszero(); b = b.div2()) {
            if (b.mod2()) res *= a;
            a *= a;
        }
        return res;
    }
    /// (*this)^|x| mod p. @throws -1 if p == 0
    BigInteger pow(const BigInteger& x, const BigInteger& p) const {
        BigInteger res(1);
        BigInteger a = *this, b = x;
        for (; !b.iszero(); b = b.div2()) {
            if (b.mod2()) res = res * a % p;
            a = a * a % p;
        }
        return res % p;
    }

    /// BASE^|x|, built directly as a 1 followed by |x| zero limbs.
    static BigInteger pow_base(const BigInteger& x) {
        std::size_t k = 0;
        for (std::size_t i = x.digits.size(); i-- > 0;) {
            k = k * static_cast<std::size_t>(BASE) + static_cast<std::size_t>(x.digits[i]);
        }
        BigInteger res;
        res.digits.assign(k, 0);
        res.digits.push_back(1);
        return res;
    }

    /// |*this| mod 2.
    int mod2() const {
        if (iszero()) return 0;
        return static_cast<int>(digits[0] & 1);
    }

    /// floor(log2(*this)) for positive values, -1 otherwise.
    long long log2() const {
        if (signum() <= 0) return -1;
        const std::size_t sz = digits.size();
        // Floating-point estimate from the top limbs, then exact correction.
        const double est = std::log2(leading(*this, sz - 1))
                         + static_cast<double>(sz - 1) * std::log2(static_cast<double>(BASE));
        long long res = static_cast<long long>(std::floor(est));
        if (res < 0) res = 0;
        BigInteger p = BigInteger(2).pow(BigInteger(res));
        while (p > *this) p = p.div2(), res--;
        while (p.mul2() <= *this) p = p.mul2(), res++;
        return res;
    }
    /// floor(sqrt(*this)) via integer Newton iteration. @throws -1 if negative
    BigInteger sqrt() const {
        if (signum() < 0) throw -1;
        if (*this <= BigInteger(1)) return *this;
        // *this < BASE^size, so BASE^ceil(size/2) is above floor(sqrt(*this)).
        BigInteger cur = pow_base(BigInteger(static_cast<long long>((digits.size() + 1) / 2)));
        while (true) {
            BigInteger next = (cur + *this / cur).div2();
            if (next >= cur) return cur;
            cur = next;
        }
    }

    BigInteger& operator += (const BigInteger& x) {return *this = *this + x;}
    BigInteger& operator -= (const BigInteger& x) {return *this = *this - x;}
    BigInteger& operator *= (const BigInteger& x) {return *this = *this * x;}
    BigInteger& operator /= (const BigInteger& x) {return *this = *this / x;}
    BigInteger& operator %= (const BigInteger& x) {return *this = *this % x;}

    BigInteger& operator *= (const int& x) {return *this = *this * x;}
    BigInteger& operator /= (const long long& x) {return *this = *this / x;}
};

#endif //BIGINTEGER_H