题解:P10446 64位整数乘法

· · 题解

题目传送门

这道题其实不难,难就难在如何处理乘法过程中溢出的问题。对于这个问题,这里给出两种方法。

方法 #1

首先,根据同余相关理论,有:

根据第二条,我们可以将原来的求 (a\times b) \bmod p 转化为求 (a\bmod p)\times(b\bmod p) \bmod p,进而防止溢出。

然而题目给的模数太大,还是无法避免溢出。

于是又根据第一条,我们可以将 b 拆成若干个整数相加的形式,分别取模再相加,最后再相乘取模。这样,溢出问题总算可以解决了。

但是,新的问题产生了:我们应该把 b 拆成哪几个整数呢?

答案是拆成几个不同的 2 的整数次幂之和,即对 b 进行二进制拆分。

举个例子吧。比如 (514 \times 114) \bmod p = (514 \bmod p) \times (2^6 + 2^5 + 2^4 + 2^1) \bmod p = (514 \bmod p) \times [(2^6 \bmod p) + (2^5 \bmod p) + (2^4 \bmod p) + (2^1 \bmod p)] \bmod p

为什么呢?其实也很显然,毕竟计算机内部就是按照二进制的方式存储的嘛,这样既方便我们操作——使用位运算,又不会拆得太多或太少,使我们能在可以接受的时间内得出结果而不会溢出,一举多得。

这样,我们就得到了大名鼎鼎的龟速乘算法,时间复杂度 \Theta(\log b)

核心代码如下:

#define int long long
int mul(int a,int b,int p){
    int res=0;
    for(;b;b>>=1){
        if(b&1) res=(res+a)%p;
        a=(a<<1)%p;
    }
    return res;
}
//有没有发现长得很像快速幂?
/*
int qpow(int a,int b,int p){
    int res=1%p;
    for(;b;b>>=1){
        if(b&1) res=res*a%p;
        a=a*a%p;
    }
    return res;
}
*/

方法 #2

再介绍一种玄学的 \Theta(1) 快速乘。

首先还是根据上面的第二条,把 ab 都先对 p 取模。

然后根据取模的定义,有 (a \times b) \bmod p = a \times b - \lfloor \displaystyle \frac{a \times b}{p} \rfloor \times p

对于 \lfloor \displaystyle \frac{a \times b}{p} \rfloor,我们可以将其强制转换成 long double 类型再进行运算。这是因为 long double 最大位数可达到 18 左右,而且 long double 在精度不够时会自动丢弃掉最低位,刚好符合我们下取整的要求。

对于其他部分,我们直接用 long long 计算即可。可是这样不是溢出了吗?没错,确实会。但是注意,这里 a \times b - \lfloor \displaystyle \frac{a \times b}{p} \rfloor \times p 的结果一定是小于 p 的,也就是说,结果只与模数的位数内的数字有关,超过的那部分对结果并没有影响。

例如,对于 a = 114514b = 1145141919810p = 123456 的情况,a \times b = 131134781805122340,而 \lfloor \displaystyle \frac{a \times b}{p} \rfloor \times p = 131134781805111552。看到了吧,第六位以外的数字都相等,一减就没了,留着他们也没用。

核心代码如下:

#define int long long
int mul(int a,int b,int p){
    a%=p,b%=p;
    int c=(long double)a*b/p;
    int res=a*b-c*p;
    if(res<0) res+=p;
    else if(res>=p) res-=p;
    return res;
}

不过听说毕竟浮点数运算难免有误差(虽然可能性不大),所以除非万不得已,还是老老实实写龟速乘吧。

最后贴一下其他部分的代码;

#include <bits/stdc++.h>
using namespace std;
template <typename T> void read(T &x){
    x=0;
    short _=1;
    char c=getchar();
    for(;!isdigit(c);c=getchar()) if(c=='-') _=-_;
    for(;isdigit(c);c=getchar()) x=(x<<1)+(x<<3)+(c^48);
    x*=_;
}
template <typename T> void write(T x){
    if(x<0) putchar('-'),x=-x;
    short _=0;
    char s[20];
    do{s[++_]=x%10;x/=10;}while(x>0);
    while(_>0) putchar(s[_--]^48);
}
int a,b,p;
signed main(){
    read(a); read(b); read(p);
    write(mul(a,b,p));
    return 0;
}