压位高精度模板(带FFT)
stripe_python · 个人记录
更新
2023/12/31 更新:
- 增加了 `sqrt` 函数。
2024/01/01 更新:
- 更正了 `compare` 函数中关于负数讨论的问题(感谢 @Alex_Wei 巨佬);
- 增加了 `to_vint` 函数;
- 更正了 FFT 乘法负数的问题;
- 更正了减法关于负数讨论的问题(感谢 @Alex_Wei 巨佬 & 谢罪)。
2024/02/05 更新:
- 经过对拍,除法负数有锅。其他暂没有发现问题。
2024/02/23 更新:
- 开学前打了「乘积最大」,发现板子还有锅。由于马上开学,没有时间 debug,本板子暂时作废。
保证
已通过所有高精度模板题(包括 FFT 板子)。
与 Python 内置大整数的运行时间对比,详见“性能测试”部分。
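实现细节
对于高精度运算,每个数按 $10^8$ 压位存储(`WIDTH` $=8$,`BASE` $=10^8$),低位在前。
压位位数
由于 FFT 做压位的上限太低,FFT 乘法按单个十进制位进行;又因为朴素乘法的复杂度为 $O(nm)$,当两数压位后的长度之积不小于 `FFT_LIMIT` $=65536$ 时改用 FFT 乘法。
复杂度: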
| 操作 | 函数名 | 时间 | 空间 |
|---|---|---|---|
| 字符串初始化 | / | | |
| 整型初始化 | / | | |
| 复制初始化 | / | | |
| 清空(置 $0$) | `clear` | | |
| 预先分配空间 | `reserve` | | |
| 读入 | / | | |
| 输出 | / | | |
| 取绝对值 | `abs` | | |
| 取反 | / | | |
| 判断是否为 $0$ | `iszero` | | |
| 乘 $2$ | `mul2` | | |
| 除 $2$ | `div2` | | |
| 比较运算符 | / | | |
| 加 | / | | |
| 减 | / | | |
| 自增 | / | | |
| 自减 | / | | |
| 乘高精(朴素) | / | | |
| 乘高精(FFT) | / | | |
| 乘单精 | / | | |
| 除高精(朴素) | / | | |
| 除高精(FFT) | / | | |
| 除单精 | / | | |
| 取余(朴素) | / | | |
| 取余(FFT) | / | | |
| 乘方(朴素) | `pow` | | |
| 乘方(FFT) | `pow` | | |
| 进制乘方 | `pow_base` | | |
| 模 $2$ | `mod2` | | |
| 平方根(朴素) | `sqrt` | | |
| 平方根(FFT) | `sqrt` | | |
性能测试
统计图(原图缺失)。
附上性能测试代码:
//对拍&测试
#include <bits/stdc++.h>
#include <windows.h>
using namespace std;
// Random source for the test generator: a Mersenne Twister seeded from random_device.
std::random_device rd;
std::mt19937 ran(rd());
// Exclusive upper bound on the number of digits in each generated operand.
const int DATA = 20000;
// Writes one test case to stdout: two random positive integers separated by a
// single space. Both operands get the same digit count (ran() % DATA, but at
// least one digit), and neither has a leading zero.
void gen() {
    const int len = ran() % DATA;
    for (int which = 0; which < 2; ++which) {
        if (which == 1) putchar(' ');
        putchar('0' + (ran() % 9 + 1));  // leading digit: 1..9
        for (int i = 1; i < len; ++i) putchar('0' + ran() % 10);
    }
}
// Timestamps (clock ticks) captured around each external run.
double aaa, bbb;
// _begin/_end record clock() before/after a run; _time is the elapsed ticks between them.
// NOTE(review): the C standard defines clock() as processor time of *this* process; timing
// children launched via system() only works where clock() reports elapsed wall time
// (as the Microsoft CRT does) — confirm on the target toolchain.
#define _begin aaa = clock()
#define _end bbb = clock()
#define _time (bbb - aaa)
// Differential test + benchmark driver (Windows-only: uses BigInteger.exe and `fc`).
// Each of T rounds: generate data.in, run the Python reference and BigInteger.exe on it,
// compare their outputs, and append "State,Python-time,BigInteger-time" to the CSV on stderr.
// Stops at the first mismatch after logging a "WA" row.
int main() {
// All CSV rows go to stderr, redirected to the result file.
freopen("result/result-2e4-div.csv", "w", stderr);
cerr << "State,Python-time,BigInteger-time" << endl;
int T = 1000;
while (T--) {
// Write the generated test case into data.in, then send any further stdout
// (including output of the system() calls) to details.out.
freopen("data.in", "w", stdout);
gen();
freopen("details.out", "w", stdout);
// NOTE(review): assumes std.py reads data.in and writes py.out — verify against std.py.
_begin; system("python std.py"); _end;
double py = _time;
_begin; system("BigInteger.exe < data.in > biginteger.out"); _end;
double big = _time;
// fc returns non-zero when the files differ.
if (system("fc py.out biginteger.out")) {
cerr << "WA,0.0,0.0" << endl;
break;
}
cerr << "AC," << py << "," << big << endl;
}
return 0;
}
# Plotting: grouped bar chart comparing BigInteger and Python run times
# for one operation (K) across the three data ranges.
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np

plt.rcParams['font.sans-serif'] = ['SimHei']  # a font with CJK glyphs for the Chinese labels

NAME = ['add', 'sub', 'div', 'mul', 'mod']
FILE = 'result-{data}-{name}.csv'
DATA = ['2e3', '2e4', '2e5']

# k[operation] -> list of [python_time, biginteger_time], one pair per data range.
k = defaultdict(list)
for n in NAME:
    for d in DATA:
        f = FILE.replace('{data}', d).replace('{name}', n)
        with open(f, 'r', encoding='utf-8') as fp:
            c = fp.read()
        # NOTE(review): the benchmark writes one header line plus at most 1000
        # "State,py,big" rows, so index 1001 only holds data if an extra line
        # (e.g. a hand-added summary row) was appended to the CSV — confirm.
        a = c.split('\n')[1001]
        x, y = map(float, a.split(',')[1:])
        k[n].append([x, y])

K = 'mod'  # operation to plot
xtick = ['2e3', '2e4', '2e5']
py = np.array([i[0] for i in k[K]])
big = np.array([i[1] for i in k[K]])
print(py, big)

n_groups = 3
index = np.arange(n_groups)
bar_width = 0.35
# Bars are centred on index (BigInteger) and index + bar_width (Python).
plt.bar(index, big, bar_width, color='b', label='BigInteger')
plt.bar(index + bar_width, py, bar_width, color='g', label='Python')
plt.title('高精度取余の性能统计图')
plt.xlabel('数据范围')
plt.ylabel('运行时间(ms)')
# Put each tick label midway between the two bars of its group.
plt.xticks(index + bar_width / 2, xtick)
plt.legend(loc='upper left')
plt.show()
以及高精板子题的测试结果作为参考(测试点平均运行时间/最大运行时间):
| 题目 | 平均运行时间/ms | 最大运行时间/ms |
|---|---|---|
| P1601 (加法) | 3.4 | 4 |
| P2142 (减法) | 3.6 | 4 |
| P1303 (朴素乘法) | 3.0 | 3 |
| P1919 (FFT) | 915.0 | 948 |
| P1480 (高精除单精) | 3.0 | 3 |
| P2005 (高精除高精) | 194.5 | 727 |
| P1932 (全套) | 75.7 | 481 |
大模板
400 行,写了半个月。
#ifndef BIGINTEGER_H
#define BIGINTEGER_H
#include <algorithm>
#include <climits>
#include <iomanip>
#include <sstream>
#include <string>
#include <vector>
#ifndef BIGINTEGER_DISABLE_FFTMUL
#include <cmath>
// Complex arithmetic plus a recursive radix-2 FFT, used by BigInteger to
// multiply large operands. Every function is `inline` and there is no mutable
// namespace-scope state, so this header can be included from several
// translation units without duplicate-symbol link errors.
namespace BigInteger_FFT {
const double PI = std::acos(-1.0);
class complex {
public:
    double real, imag;
    complex(double x = 0.0, double y = 0.0) : real(x), imag(y) {}
    complex operator + (const complex& other) const {
        return complex(real + other.real, imag + other.imag);
    }
    complex operator - (const complex& other) const {
        return complex(real - other.real, imag - other.imag);
    }
    complex operator * (const complex& other) const {
        return complex(real * other.real - imag * other.imag, real * other.imag + other.real * imag);
    }
};
// In-place FFT of a[0..len).
//   len  : transform length; a power of two with len <= a.size().
//   type : 1.0 for the forward transform, -1.0 for the inverse (the caller divides by len).
inline void fft(std::vector<complex>& a, int len, double type) {
    if (len == 1) return;
    const int half = len >> 1;
    std::vector<complex> even(half), odd(half);
    for (int i = 0; i < len; i++) {
        if (i & 1) odd[i >> 1] = a[i];
        else even[i >> 1] = a[i];
    }
    fft(even, half, type), fft(odd, half, type);
    const complex w1(std::cos(2.0 * PI / len), std::sin(2.0 * PI / len) * type);
    complex w(1.0, 0.0);
    for (int i = 0; i < half; i++, w = w * w1) {
        const complex t = odd[i] * w;  // butterfly: twiddled odd term computed once
        a[i] = even[i] + t;
        a[i + half] = even[i] - t;
    }
}
// Convolution of two coefficient vectors (entry i is the coefficient of x^i).
// Returns a.size() + b.size() - 1 entries, each rounded as value + 0.5 truncated
// (i.e. to nearest for the non-negative values digit vectors produce).
// Returns an empty vector if either input is empty.
inline std::vector<int> mul(const std::vector<int>& a, const std::vector<int>& b) {
    if (a.empty() || b.empty()) return {};
    const int n = static_cast<int>(a.size()) - 1, m = static_cast<int>(b.size()) - 1;
    int len = 1;
    while (len <= n + m) len <<= 1;
    std::vector<complex> f(len), g(len);  // local buffers: no shared state between calls
    for (int i = 0; i <= n; i++) f[i] = complex(a[i], 0.0);
    for (int i = 0; i <= m; i++) g[i] = complex(b[i], 0.0);
    fft(f, len, 1.0), fft(g, len, 1.0);
    for (int i = 0; i < len; i++) f[i] = f[i] * g[i];
    fft(f, len, -1.0);
    std::vector<int> res;
    res.reserve(n + m + 1);
    for (int i = 0; i <= n + m; i++) res.push_back(static_cast<int>(f[i].real / len + 0.5));
    return res;
}
} // namespace BigInteger_FFT
#endif // BIGINTEGER_DISABLE_FFTMUL
// Arbitrary-precision signed integer.
//
// Representation: `digits` holds base-10^8 limbs, least significant first, and
// `sign` is true for values >= 0. Every operation keeps two invariants:
// there are no most-significant zero limbs, and zero is an empty `digits` with
// sign == true. Division and remainder truncate toward zero, like the built-in
// integer operators. Errors (division by zero, sqrt of a negative) are reported
// with `throw -1`.
class BigInteger {
private:
    using digit_t = long long;
    static const int WIDTH = 8;                // decimal digits stored per limb
    static const digit_t BASE = 100000000;     // 10^WIDTH
    static const long long FFT_LIMIT = 65536;  // use FFT once (limbs of a) * (limbs of b) reaches this
    //static const long long NEWTON_LIMIT = 65536;
    std::vector<digit_t> digits;               // limbs, least significant first
    bool sign = true;                          // true: value >= 0

    // Restores the invariants: drop most-significant zero limbs and make zero non-negative.
    void trim() {
        while (!digits.empty() && digits.back() == 0) digits.pop_back();
        if (digits.empty()) sign = true;
    }

    // Three-way comparison of |*this| with |x| (operands must be trimmed): -1, 0 or 1.
    int compare_abs(const BigInteger& x) const {
        if (digits.size() != x.digits.size()) return digits.size() < x.digits.size() ? -1 : 1;
        for (std::size_t i = digits.size(); i-- > 0;) {
            if (digits[i] != x.digits[i]) return digits[i] < x.digits[i] ? -1 : 1;
        }
        return 0;
    }

    // Three-way signed comparison of *this with x: -1, 0 or 1.
    int compare(const BigInteger& x) const {
        const bool zero_a = iszero(), zero_b = x.iszero();
        if (zero_a && zero_b) return 0;
        // A zero counts as non-negative whatever its stored sign flag says.
        const bool pos_a = zero_a || sign, pos_b = zero_b || x.sign;
        if (pos_a != pos_b) return pos_a ? 1 : -1;
        if (zero_a) return -1;  // 0 < x (x > 0 here)
        if (zero_b) return 1;   // *this > 0 == x
        const int c = compare_abs(x);
        return pos_a ? c : -c;  // among negatives, the larger magnitude is the smaller value
    }

#ifndef BIGINTEGER_DISABLE_FFTMUL
    // Multiplies via FFT over single decimal digits (packing more digits per
    // coefficient risks floating-point rounding errors), then applies the sign.
    BigInteger fft_mul(const BigInteger& other) const {
        // Use magnitudes: a leading '-' must never reach the FFT as a "digit".
        const std::string astr = abs().to_str(), bstr = other.abs().to_str();
        const int n = static_cast<int>(astr.size()), m = static_cast<int>(bstr.size());
        std::vector<int> a(n), b(m);
        for (int i = 0; i < n; i++) a[n - 1 - i] = astr[i] - '0';  // a[i] = coefficient of 10^i
        for (int i = 0; i < m; i++) b[m - 1 - i] = bstr[i] - '0';
        const std::vector<int> prod = BigInteger_FFT::mul(a, b);
        // Propagate decimal carries, producing digits least significant first.
        std::string s;
        s.reserve(prod.size() + 8);
        long long carry = 0;
        for (std::size_t i = 0; i < prod.size() || carry; i++) {
            const long long cur = carry + (i < prod.size() ? prod[i] : 0);
            s.push_back(static_cast<char>('0' + cur % 10));
            carry = cur / 10;
        }
        std::reverse(s.begin(), s.end());
        BigInteger res(s);
        res.sign = (sign == other.sign);
        res.trim();
        return res;
    }
#endif // BIGINTEGER_DISABLE_FFTMUL

public:
    BigInteger() : sign(true) {}
    BigInteger(const std::string& s) {*this = s;}
    BigInteger(const long long& x) {*this = x;}
    BigInteger(const BigInteger& x) = default;
    BigInteger(BigInteger&& x) = default;
    void clear() {sign = true, digits.clear();}
    void reserve(std::size_t k) {digits.reserve(k);}

    // Parses an optionally '-'-prefixed decimal string. "" and "-" yield 0;
    // leading zeros are ignored and "-0" becomes 0.
    BigInteger& operator = (const std::string& s) {
        digits.clear(), sign = true;
        if (s.empty() || s == "-") return *this;
        const int first = (s[0] == '-') ? 1 : 0;  // index of the first digit
        sign = (first == 0);
        digits.reserve(s.size() / WIDTH + 1);
        // Cut WIDTH-digit chunks from the right; the leftmost chunk may be shorter.
        for (int j = static_cast<int>(s.size()) - 1; j >= first; j -= WIDTH) {
            const int start = std::max(first, j - WIDTH + 1);
            digits.emplace_back(std::stoi(s.substr(start, j - start + 1)));
        }
        trim();  // removes limbs produced by leading zeros and normalises "-0"
        return *this;
    }
    // Exact for the whole long long range, including LLONG_MIN.
    BigInteger& operator = (const long long& x) {
        digits.clear(), sign = (x >= 0);
        // Unsigned negation is well defined, so LLONG_MIN needs no special case.
        unsigned long long n = x < 0 ? 0ULL - static_cast<unsigned long long>(x)
                                     : static_cast<unsigned long long>(x);
        while (n) {
            digits.push_back(static_cast<digit_t>(n % BASE));
            n /= BASE;
        }
        return *this;
    }
    BigInteger& operator = (const BigInteger& x) = default;
    BigInteger& operator = (BigInteger&& x) = default;
    ~BigInteger() = default;

    // Writes the decimal form; the stream's fill character is restored afterwards.
    friend std::ostream& operator << (std::ostream& out, const BigInteger& x) {
        if (x.iszero()) return out << 0;
        if (!x.sign) out << '-';
        out << x.digits.back();
        const char old_fill = out.fill('0');
        for (std::size_t i = x.digits.size() - 1; i-- > 0;) out << std::setw(WIDTH) << x.digits[i];
        out.fill(old_fill);
        return out;
    }
    friend std::istream& operator >> (std::istream& in, BigInteger& x) {
        std::string s; in >> s; x = s; return in;
    }
    std::string to_str() const {
        std::stringstream ss;
        ss << *this;
        return ss.str();
    }
    // Decimal digits of |*this|, most significant first ({0} for zero).
    std::vector<int> to_vint() const {
        if (iszero()) return {0};
        const std::string str = abs().to_str();
        std::vector<int> res(str.size());
        for (std::size_t i = 0; i < str.size(); i++) res[i] = str[i] - '0';
        return res;
    }
    BigInteger abs() const {
        BigInteger res = *this; res.sign = true; return res;
    }
    // Negation; zero stays non-negative so that -0 == 0.
    BigInteger operator - () const {
        if (iszero()) return *this;
        BigInteger res = *this; res.sign = !sign; return res;
    }
    bool iszero() const {
        return digits.empty() || (digits.size() == 1 && digits.back() == 0);
    }
    // *this * 2 (sign preserved).
    BigInteger mul2() const {
        BigInteger res;
        res.sign = sign;
        res.digits.reserve(digits.size() + 1);
        digit_t carry = 0;
        for (digit_t d : digits) {
            const digit_t cur = d * 2 + carry;
            res.digits.push_back(cur % BASE);
            carry = cur / BASE;
        }
        if (carry) res.digits.push_back(carry);
        res.trim();
        return res;
    }
    // *this / 2, truncated toward zero.
    BigInteger div2() const {
        BigInteger res = *this;
        for (std::size_t i = res.digits.size(); i-- > 0;) {
            if (i > 0 && (res.digits[i] & 1)) res.digits[i - 1] += BASE;  // odd limb: lend half a BASE down
            res.digits[i] >>= 1;
        }
        res.trim();
        return res;
    }
    bool operator < (const BigInteger& x) const {return compare(x) < 0;}
    bool operator > (const BigInteger& x) const {return compare(x) > 0;}
    bool operator == (const BigInteger& x) const {return compare(x) == 0;}
    bool operator != (const BigInteger& x) const {return compare(x) != 0;}
    bool operator <= (const BigInteger& x) const {return compare(x) <= 0;}
    bool operator >= (const BigInteger& x) const {return compare(x) >= 0;}
    BigInteger operator + (const BigInteger& x) const {
        // Opposite signs reduce to a subtraction of two same-signed values.
        if (sign != x.sign) return sign ? *this - x.abs() : x - abs();
        BigInteger res;
        res.sign = sign;  // same signs: the sum keeps that sign
        const std::size_t n = std::max(digits.size(), x.digits.size());
        res.digits.reserve(n + 1);
        digit_t carry = 0;
        for (std::size_t i = 0; i < n; i++) {
            const digit_t cur = (i < digits.size() ? digits[i] : 0)
                              + (i < x.digits.size() ? x.digits[i] : 0) + carry;
            res.digits.push_back(cur % BASE);
            carry = cur / BASE;
        }
        if (carry) res.digits.push_back(carry);
        res.trim();
        return res;
    }
    BigInteger operator - (const BigInteger& x) const {
        // Opposite signs: a - b == a + (-b), which has matching signs.
        if (sign != x.sign) return *this + (-x);
        const int cmp = compare_abs(x);
        if (cmp == 0) return BigInteger();
        // Same signs: |result| = |larger| - |smaller|; the sign flips when |x| > |*this|.
        const BigInteger& larger = cmp > 0 ? *this : x;
        const BigInteger& smaller = cmp > 0 ? x : *this;
        BigInteger res;
        res.sign = cmp > 0 ? sign : !sign;
        res.digits.reserve(larger.digits.size());
        digit_t borrow = 0;
        for (std::size_t i = 0; i < larger.digits.size(); i++) {
            digit_t cur = larger.digits[i] - borrow - (i < smaller.digits.size() ? smaller.digits[i] : 0);
            borrow = cur < 0 ? 1 : 0;
            if (cur < 0) cur += BASE;
            res.digits.push_back(cur);
        }
        res.trim();
        return res;
    }
    BigInteger& operator ++ () {return *this = *this + 1LL;}
    BigInteger& operator -- () {return *this = *this - 1LL;}
    BigInteger operator * (const BigInteger& x) const {
        if (iszero() || x.iszero()) return BigInteger();
        const std::size_t n = digits.size(), m = x.digits.size();
#ifndef BIGINTEGER_DISABLE_FFTMUL
        if (static_cast<long long>(n) * static_cast<long long>(m) >= FFT_LIMIT) return fft_mul(x);
#endif
        BigInteger res;
        res.sign = (sign == x.sign);
        res.digits.assign(n + m, 0);
        for (std::size_t i = 0; i < n; i++) {
            digit_t carry = 0;
            for (std::size_t j = 0; j < m; j++) {
                // < BASE + (BASE-1)^2 + BASE: fits comfortably in 64 bits.
                const digit_t cur = res.digits[i + j] + digits[i] * x.digits[j] + carry;
                res.digits[i + j] = cur % BASE;
                carry = cur / BASE;
            }
            res.digits[i + m] += carry;  // slot i+m is still untouched by earlier rows
        }
        res.trim();
        return res;
    }
    BigInteger operator * (const int& x) const {
        if (x >= BASE * 10) return *this * BigInteger(static_cast<long long>(x));
        // |x| <= 2^31 keeps limb * |x| + carry far below LLONG_MAX; no INT_MIN negation UB.
        const digit_t c = x >= 0 ? static_cast<digit_t>(x) : -static_cast<digit_t>(x);
        BigInteger res;
        res.sign = (sign == (x >= 0));
        res.digits.reserve(digits.size() + 2);
        digit_t carry = 0;
        for (digit_t d : digits) {
            const digit_t cur = d * c + carry;
            res.digits.push_back(cur % BASE);
            carry = cur / BASE;
        }
        while (carry) {  // the final carry can exceed one limb
            res.digits.push_back(carry % BASE);
            carry /= BASE;
        }
        res.trim();
        return res;
    }
    // Division by a machine integer, truncated toward zero. Throws -1 when x == 0.
    BigInteger operator / (const long long& x) const {
        if (x == 0) throw -1;
        if (iszero()) return BigInteger();
        // The running remainder is < |x|; cur * BASE must not overflow, so large
        // divisors (and LLONG_MIN) go through big-integer division instead.
        const long long limit = LLONG_MAX / BASE;
        if (x >= limit || x <= -limit) return *this / BigInteger(x);
        const digit_t div = x >= 0 ? x : -x;
        BigInteger res;
        res.sign = (sign == (x >= 0));
        res.digits.resize(digits.size());
        digit_t cur = 0;
        for (std::size_t i = digits.size(); i-- > 0;) {
            cur = cur * BASE + digits[i];
            res.digits[i] = cur / div;
            cur %= div;
        }
        res.trim();
        return res;
    }
    // Long division one limb at a time, truncated toward zero. Throws -1 when x == 0.
    BigInteger operator / (const BigInteger& x) const {
        if (x.iszero()) throw -1;
        const BigInteger b = x.abs();
        if (compare_abs(b) < 0) return BigInteger();
        BigInteger res, rem;  // rem: non-negative running remainder
        res.digits.assign(digits.size(), 0);
        for (std::size_t i = digits.size(); i-- > 0;) {
            rem.digits.insert(rem.digits.begin(), digits[i]);  // rem = rem * BASE + digits[i]
            rem.trim();
            if (rem.compare_abs(b) < 0) continue;  // quotient limb is 0
            // Largest q in [1, BASE) with b * q <= rem.
            digit_t lo = 1, hi = BASE - 1;
            while (lo < hi) {
                const digit_t mid = (lo + hi + 1) / 2;
                if ((b * static_cast<int>(mid)).compare_abs(rem) <= 0) lo = mid;
                else hi = mid - 1;
            }
            res.digits[i] = lo;
            rem = rem - b * static_cast<int>(lo);
        }
        res.sign = (sign == x.sign);
        res.trim();
        return res;
    }
    // Remainder with the sign of *this (consistent with truncating division).
    BigInteger operator % (const BigInteger& x) const {
        BigInteger c = *this / x * x;
        return *this - c;
    }
    // *this raised to the power |x| by binary exponentiation.
    BigInteger pow(const BigInteger& x) const {
        BigInteger res(1LL);
        BigInteger a = *this, b = x;
        while (!b.iszero()) {
            if (b.mod2()) res *= a;
            b = b.div2();
            if (!b.iszero()) a *= a;  // skip the final, unused (and largest) squaring
        }
        return res;
    }
    // (*this ^ |x|) % p by binary exponentiation.
    BigInteger pow(const BigInteger& x, const BigInteger& p) const {
        BigInteger res(1LL);
        BigInteger a = *this, b = x;
        while (!b.iszero()) {
            if (b.mod2()) res = res * a % p;
            b = b.div2();
            if (!b.iszero()) a = a * a % p;
        }
        return res % p;
    }
    // BASE ^ x, i.e. 10^(8x).
    static BigInteger pow_base(const BigInteger& x) {
        const long long base_value = BASE;  // copy, so BASE is never bound to a reference (odr-use)
        return BigInteger(base_value).pow(x);
    }
    // Lowest bit of |*this| (0 or 1).
    int mod2() const {
        if (iszero()) return 0;
        return digits[0] & 1;
    }
    // floor(log2(*this)) for positive values; -1 for values <= 0.
    long long log2() const {
        long long res = -1;
        for (BigInteger p(1LL); p <= *this; p = p.mul2()) ++res;
        return res;
    }
    // floor(sqrt(*this)) by binary search. Throws -1 for negative values.
    BigInteger sqrt() const {
        if (*this < 0) throw -1;
        BigInteger x = *this;
        if (x <= 1) return x;
        BigInteger l = 1, r = x, res = 0;
        while (l <= r) {
            BigInteger mid = (l + r).div2();
            if (mid * mid <= x) l = mid + 1, res = mid;
            else r = mid - 1;
        }
        return res;
    }
    BigInteger& operator += (const BigInteger& x) {return *this = *this + x;}
    BigInteger& operator -= (const BigInteger& x) {return *this = *this - x;}
    BigInteger& operator *= (const BigInteger& x) {return *this = *this * x;}
    BigInteger& operator /= (const BigInteger& x) {return *this = *this / x;}
    BigInteger& operator %= (const BigInteger& x) {return *this = *this % x;}
    BigInteger& operator *= (const int& x) {return *this = *this * x;}
    BigInteger& operator /= (const long long& x) {return *this = *this / x;}
};
#endif //BIGINTEGER_H