逆元

· · 个人记录

逆元

这次我们来谈论一下逆元的几种求法

逆元的概念

逆元(Inverse element)就是在mod意义下,不能直接除以一个数,而要乘以它的逆元

比如a \times b \equiv 1 \pmod p

那么a,b互为模n意义下的逆元(a和p必须互质才存在逆元),可以理解为b就是模意义下的\frac1b \pmod p

比如你要算x/a,就可以改成x \times b \mod p

### 注意 在以下的所有方法中,如果想让结果是一个正数(最小) 那么就要用`x=(x%mod+mod)%mod` 这样就可以让x变成正数 ## 方法一:扩展欧几里得求逆元 #### 扩展欧几里得 这个是用来求解二元一次方程的经典方法 它可以找到整数x,y,满足$ax+by=gcd(a,b)

下面给出证明:

\begin{cases} x=1\\ y=0 \end{cases}

x=y1,y=x1-a/b\times y1,运用递归,因为所有x,y都可以由上一步得到了ax+by=gcd(a,b)

代码如下:

int exgcd(int a,int b,int &x,int &y)//返回的是a,b的最大公约数
{
    if(b==0)
    {
        x=1;
        y=0;
        return 1;
    }
    int g=exgcd(b,a%b,x,y),z=x;
    x=y;
    y=z-(a/b)*y;
    return g;
}

求逆元

ax\equiv 1\pmod p中a的逆元就如下:

exgcd(a,p,x,y);
printf("%d",x);

方法二:费马小定理求逆元

p为质数,且a,p互质,则a^{p-1}\equiv 1\pmod p

所以a^{p-2}就是a在mod p意义下的逆元

代码如下(一般配合快速幂食用):

int fastpow(int a,int b,int mod)
{
    int ans=1;
    while(b)
    {
        if(b&1)
            a=(a*b)%mod;
        b=b*b%mod;
        b>>=1;
    }
    return ans;
}
int getinv(int a,int p,int mod)
{
    return fastpow(a,p-2,mod);
}

方法三:线性求逆元

t=p/i,k=p\%i,则p=it+k\\

\begin{aligned} it+k &\equiv 0 \pmod p\\ it &\equiv -k \pmod p \end{aligned}

i',k'分别为i,k \ mod\ p意义下的逆元

\begin{aligned} it \times i'k' &\equiv -k \times i'k' \pmod p\\ tk' &\equiv -i' \pmod p\\ i' &\equiv -tk' \pmod p \end{aligned}

将t,k带入,用inv[]数组表示逆元,并取正

\begin{aligned} inv[i]&=-(p/i)\ inv[p\%i] \ \%p\\ &=(p-p/i)\ inv[p\%i]\ \%p \end{aligned}

代码如下:

int i;
inv[1]=1;
for(i=2;i<=N;i++)
{
    inv[i]=(mod-mod/i)*inv[mod%i]%mod
}

当然了也可以把这个改造成递归的形式

int inv(int a)
{
    if(a==1)
        return 1;
    return (mod-mod/i)*inv(mod%i)%mod
}

求阶乘的逆元

inv[i]$表是i mod p意义下的逆元,可以理解为$\frac1i \pmod p

所以求出阶乘的最高项的逆元就好递推出其他的逆元了,可以使用前面的三种方法

int i;
for(i=1;i<=N;i++)
{
    f[i]=f[i-1]*i%mod;//阶乘
}
inv[N]=getinv(f[N]);
for(i=N-1;i>=1;i--)
{
    inv[i]=inv[i+1]*i;//阶乘的逆元
}

逆元求组合数

组合数C^m_n=\frac{n!}{m!(n-m)!}约分后得到\prod^{m-1}_{i=0}\frac{n-i}{i+1}

所以代码如下

int C(int n,int m)
{
    if(n<0||m<0||n<m)
        return 0;
    if(n==0||m==0)
        return 1;
    int i,ans=1;
    for(i=1;i<=m;i++)
    {
        ans=ans*(n-i+1)%mod*inv[i]%mod;
    }
    return ans;
}