Barrett Reduction 乘法取模加速

· · 个人记录

今年写 USACO Open 的时候看到了一种神奇的乘法取模加速方法,下面简单叙述。

使用条件:

  1. 计算 a\times b \bmod{p},其中 0\le a,b<p,且 p 是一个编译时未知的常数,例如题目输入的(需要多次取模的)模数。如果 p 在编译时已知,可以由编译器完成编译时优化。

  2. 如果 pint 级别的数,那么就需要使用 __int128

这种算法的主要思想是这样的。

通常,我们(人工)计算取模,用的是 a\bmod b=a-\left\lfloor \dfrac{a}{p} \right\rfloor \cdot p。这个计算中有除法,开销比较大,我们希望能用乘法替换除法,计算出 q=\left\lfloor \dfrac{a}{p} \right\rfloor

首先我们想到一种方法,即把式子写成 q=\lfloor a\cdot p^{-1} \rfloor,其中 p^{-1} 是用浮点数表示的 p 的倒数。但浮点数计算仍然缓慢,我们还需要再优化。

这时候就有神仙提出:我们可以钦定一个整数 k,再弄出一个整数 m,使得 \dfrac{m}{2^k}\approx\dfrac{1}{p},那么 q 不就约等于 \dfrac{a\cdot m}{2^k} 了吗?这样除法运算就被拆成了一次乘法和一次位移,速度大大加快了。

具体怎么求 m 呢?由于 \dfrac{m}{2^k}\approx \dfrac{1}{p},因此 m\approx \dfrac{2^k}{p}。只需要预处理出这个除法的结果,将来就可以一劳永逸,不需要除法了。

据说是为了防止算出的商超过实际的商,我们一般取 m=\left\lfloor \dfrac{2^k}{p} \right\rfloor

那么 k 取多大呢?假设 2^k\approx p,那么 m 的可能取值就很有限,导致算出的 q 与实际的商有较大偏离。

这里,我们取 k\ge \lceil 2\log_2 p \rceil,也就是使得 2^k\approx p^2。下面我们证明,这样取 k 时,0\le a-pq<p,也就是我们稍后在计算余数 a-pq 时,得到的答案至多需要再做一次减法 不需要再调整。

由于 q=\dfrac{am}{2^k},因此 pq=\dfrac{apm}{2^k}a-pq=\dfrac{a}{2^k}\cdot (2^k-pm)

由于 m=\left\lfloor \dfrac{2^k}{p} \right\rfloor,因此 \dfrac{2^k}{p}-1<m\le \dfrac{2^k}{p},于是 2^k-p<pm\le 2^k,即 0 \le 2^k-pm< p

如果 k\ge \lceil 2\log_2 p \rceil,那么就有 2^k\ge p^2;考虑到 a\bmod{p} 下两个数的乘积,不会达到 p^2,因此 0<\dfrac{a}{2^k}<1;于是 a-pq=\dfrac{a}{2^k}\cdot (2^k-pm) 就在 [0,p) 的范围内了。

总结这个算法的过程如下:

  1. 根据 p 的规模选取合适的 k,一般要求 k\ge \lceil 2\log_2 p \rceil

  2. 根据 k,p 预处理出 m=\left\lfloor \dfrac{2^k}{p} \right\rfloor

  3. 实际计算时,用 q=\dfrac{a\cdot m}{2^k} 计算出商,再用 r=a-pq 得出余数。

哪里需要用到 __int128 呢?如果 \log_2 p\approx 32,那么就有 k\approx 2^{64}, m\approx 2^{32};而 a\approx p^2\approx 2^{64},因此计算 a\cdot m 时需要用到 __int128。此外,如果 k 取得稍大(大于等于 63),那么在计算 m 时也需使用 __int128

参考实现如下:

struct Kasumul{
    ll p,m; //p 表示上面的模数, m 为取模参数
    //构造一个模数为 p 的取模器
    Kasumul(ll tp):p(tp),m(ll((lll(1)<<63)/tp)){}
    //calc(x) 即计算 x%p 的值
    ll calc(ll x){
        ll q=((lll(x)*m)>>63);
        ll aans=x-q*p;
        MD(aans); //这行有一定必要,参见 LOJ #3300. 「联合省选 2020 A」组合数问题
        return aans;
    }
};

Kasumul muler(2); //因为没有 void 构造函数,所以需要用参数 2 来初始化

//在 main 函数中
const int p=10000007;
muler=Kasumul(p);

//使用时
int a=1235123,b=72343451,ans;
//下面这行代码相当于 ans=1ll*a*b%p;
ans=muler.calc(1ll*a*b);

实际测试选用了 【ZJOI2010】排列计数,使用上述优化的程序的运行时间约为通常程序的 70 \%