题解 P5491 【【模板】二次剩余】

· · 题解

在写这道题前,我们需要了解一个方程。

x^2 \equiv n \pmod p

NR and QR

若对于奇素数 p 和满足 n \nmid p 的整数 n,该方程的 x 有整数解,则我们称 n 是关于模 p 的二次剩余,记做 QR 。而对于无整数解的称呼其为模 p 的二次非剩余。记做 NR

而当 n \equiv 0 \pmod p 的情况,也就是 n \mid p ,这种情况下 n 既不是 QR 也不是 NR

QR与NR的乘法法则

对于 QRNR ,我们有如下性质。

性质一:QR \times QR=QR

性质二:QR \times NR=NR

性质三:NR \times NR=QR

性质四:NR \times QR=NR

性质与 1-1 之间的性质很像,所以我们可以设 QR=1NR=-1

因为篇幅有限,性质的证明请自行Google。或阅读 《数论概论》

legendre 符号

所以对于QRNR,我们可以定义 legendre 符号。

\left(\frac{n}{p} \right)= \begin{cases} 1& \text{n 是模p的二次剩余}\\ -1& \text{n 是模p的二次非剩余} \end{cases}

欧拉准则

\left(\frac{n}{p} \right)\equiv n^{\frac{p-1}{2}} \pmod p

Cipolla 算法

我们可以找到一个整数 a,使得 \left(\frac{a^2-n}{p} \right)=-1,定义 w^2=a^2-n

然后 (a+w)^{\frac{p+1}{2}} 教室原方程的一个解,而因为另一个解是在模p意义下对称的,所以就为 p-(a+w)^{\frac{p+1}{2}}

具体证明自行Google.

实现代码

#include <cstdio>
#include <cstdlib>
#include <ctime>

struct complex {
    long long x, y;
};

long long w;

complex mul(complex a, complex b, long long p) {
    complex ans = {0, 0};
    ans.x = (((a.x * b.x) % p + (a.y * b.y) % p * w % p) % p + p) % p;
    ans.y = (((a.x * b.y) % p + (a.y * b.x) % p) % p + p) % p;
    return ans;
}

long long complexPow(complex a, long long b, long long p) {
    complex ans = {1, 0};
    for (; b; b >>= 1) {
        if (b & 1) ans = mul(ans, a, p);
        a = mul(a, a, p);
    }
    return ans.x % p;
}

long long mypow(long long a, long long b, long long p) {
    long long ans = 1;
    for (; b; b >>= 1) {
        if (b & 1)
            ans = (ans * a) % p;
        a = (a * a) % p;
    }
    return ans;
}

long long cipolla(long long n, long long p) {
    n %= p;
    if (p == 2) 
        return n;
    if (mypow(n, (p - 1) / 2, p) == p - 1) 
        return -1;
    long long a = 1;
    while (1) {
        a = rand() % p;
        w = ((a * a % p - n) % p + p) % p;
        if (mypow(w, (p - 1) / 2, p) == p - 1) 
            break;
    }
    complex x = {a, 1};
    return complexPow(x, (p + 1) / 2, p);
}

int main() {
    srand((unsigned)time(0));
    int t;
    long long n, p;
    scanf("%d", &t);
    while (t--) {
        scanf("%lld %lld", &n, &p);
        long long k1 = cipolla(n, p);
        if (k1 == -1) {
            printf("Hola!\n");
            continue;
        }
        long long k2 = p - k1;
        if (k1 > k2)
            k1 ^= k2 ^= k1 ^= k2;
        if (k1 == k2)
            printf("%lld\n", k1);
        else
            printf("%lld %lld\n", k1, k2);
    }
    return 0;
}