Index Calculus 算法

· · 题解

  1. 简介

Index Calculus 算法是一种基于原根的快速计算模数为质数的离散对数的算法。

  1. 算法

前置芝士:原根、欧拉定理、高斯消元、线性同余方程

2.1. 预处理(一)

选定一个阈值 N,线性筛出 N 以内的质数。时间复杂度为 O(N)

论文告诉我们:Ne^{\frac{\sqrt{\ln p \ln \ln p}}{2} + 1} 时较优。

2.2. 预处理(二)

设模数为 p

求出 p 的一个原根 g,枚举 / 随机化 0 \leq x < p,令 y = g^x \bmod p,再求出其标准分解式 y = \displaystyle\prod_{i = 1}^k P_i^{q_i}。若 y 的标准分解式中仅含 \leq N 的质数,则加入方程 \displaystyle\sum_{i = 1}^k q_i z_{P_i} = x。加入 \alpha \pi(n) 个(这里的 \alpha 可以依情况调整,一般取 3 \sim 4 时表现较优)方程后停止并进行 \bmod \varphi(p) 意义下的高斯消元。重复执行本过程直到高斯消元成功。

2.3. 算法过程(一)

设方程为 g^x \equiv a \pmod p

枚举 / 随机化 0 \leq y < p,令 z = ag^y \bmod p,再求出其标准分解式 z = \displaystyle\prod_{i = 1}^k P_i^{q_i}。若 z 的标准分解式中仅含 \leq N 的质数,则 x \equiv \displaystyle\sum_{i = 1}^k q_i z_{P_i} - y \pmod {\varphi(p)}。根据原根的性质可得此时 x 为最小非负整数解。

2.4. 算法过程(二)

设方程为 a^x \equiv b \pmod p

首先用上面的算法求出满足 g^y \equiv a \pmod pg^z \equiv b \pmod p 的最小 x, y,代入原方程可得:g^{xy} \equiv g^z \pmod p

由于底数相同,把指数拿出来,有:xy \equiv z \pmod {\varphi(p)}

用线性同余方程解出 x 的最小非负整数解即可。

综上,令 \beta = \log_N p,时间复杂度为 O(p^{\frac{1}{4}} + \alpha^{\alpha} \beta(\frac{N}{\ln N})^2 + \beta(\frac{N}{\ln N})^3 \log p) \sim O(\frac{N \alpha^{\alpha}}{\ln N})

注:在实际实现时,由于 rand() 常数较大,所以如果采用随机化,建议直接对 g^x 依次求幂,找到任意合法情况后退出。

  1. 例题

本题在前述过程的基础上还需要 Pollard-Rho 算法和一些玄学剪枝。

代码:

#include <iostream>
#include <map>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <ctime>

using namespace std;

typedef long long ll;
typedef unsigned long long ull;
typedef __int128 lll;

const int N = 717 + 7, M = 239 + 7;

typedef struct Matrix_tag {
    int n;
    int m;
    ll a[N][M];
    Matrix_tag(){
        memset(a, 0, sizeof(a));
    }
} Matrix;

const int K = 12, P = 1504 + 7, Q = (1 << 6) - 1, S = 1 << 7;
int test_prime[K + 7] = {0, 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37};
ll prime[M], ind[M];
ull inv_prime[M];
bool p[P];
map<ll, int> mp;

inline ull inv(ull x){
    ull ans = 1;
    for (register int i = 1; i <= Q; i++){
        ans *= x;
        x *= x;
    }
    return ans;
}

inline int init1(ll mod){
    int cnt = 0;
    ll limit = exp(sqrt(log(mod) * log(log(mod))) / 2.0 + 1.0);
    p[0] = p[1] = true;
    for (register ll i = 2; i <= limit; i++){
        if (!p[i]){
            cnt++;
            prime[cnt] = i;
            inv_prime[cnt] = inv(i);
        }
        for (register int j = 1; j <= cnt && i * prime[j] <= limit; j++){
            p[i * prime[j]] = true;
            if (i % prime[j] == 0) break;
        }
    }
    return cnt;
}

inline ll rand64(){
#if RAND_MAX == 0x7fff
    return (ll)rand() * rand() * rand() * rand();
#else
    return (ll)rand() * rand();
#endif
}

inline ll quick_pow(ll x, ll p, ll mod){
    ll ans = 1;
    while (p){
        if (p & 1) ans = (lll)ans * x % mod;
        x = (lll)x * x % mod;
        p >>= 1;
    }
    return ans;
}

ll gcd(ll a, ll b){
    return b == 0 ? a : gcd(b, a % b);
}

void exgcd(ll a, ll b, ll &x, ll &y){
    if (b == 0){
        x = 1;
        y = 0;
        return;
    }
    ll t;
    exgcd(b, a % b, x, y);
    t = x;
    x = y;
    y = t - a / b * y;
}

inline ll inv(ll a, ll b){
    ll x, y;
    exgcd(a, b, x, y);
    return (x % b + b) % b;
}

inline bool gauss(Matrix &a, ll mod){
    int ni = a.n + 1;
    for (register int i = 1; i <= a.n; i++){
        if (gcd(a.a[i][i], mod) != 1){
            int t = i;
            for (register int j = i + 1; j <= a.m; j++){
                if (gcd(a.a[j][i], mod) == 1){
                    t = j;
                    break;
                }
            }
            if (i == t) return false;
            swap(a.a[i], a.a[t]);
        }
        ll cur_inv = inv(a.a[i][i], mod);
        for (register int j = i + 1; j <= a.m; j++){
            if (a.a[j][i] != 0){
                ll t = (lll)cur_inv * a.a[j][i] % mod;
                for (register int k = i; k <= ni; k++){
                    a.a[j][k] = ((a.a[j][k] - (lll)t * a.a[i][k] % mod) % mod + mod) % mod;
                }
            }
        }
    }
    for (register int i = a.n; i >= 1; i--){
        for (register int j = i + 1; j <= a.n; j++){
            a.a[i][ni] = ((a.a[i][ni] - (lll)a.a[i][j] * a.a[j][ni] % mod) % mod + mod) % mod;
        }
        a.a[i][ni] = (lll)a.a[i][ni] * inv(a.a[i][i], mod) % mod;
    }
    return true;
}

inline void init2(ll g, ll p, int cnt){
    int limit = cnt * 3, cnt_i = cnt + 1;
    ll phi_p = p - 1;
    Matrix a;
    a.n = cnt;
    a.m = limit;
    do {
        int i = 0;
        for (register ll j = rand64() % phi_p, k = quick_pow(g, j, p); i < limit; j = rand64() % phi_p, k = quick_pow(g, j, p)){
            for (register ll l = k, x = 1; x < phi_p; l = (lll)l * k % p, x++){
                int t1 = __builtin_ctzll(l);
                ull t2 = l >> t1;
                for (register int y = 2; y <= cnt && t2 > 1; y++){
                    if ((y == 10 && t2 > 1e15) || (y == 30 && t2 > 1e12)) break;
                    while (t2 * inv_prime[y] < t2) t2 *= inv_prime[y];
                }
                if (t2 == 1){
                    i++;
                    a.a[i][1] = t1;
                    t2 = l >> t1;
                    for (register int y = 2; y <= cnt; y++){
                        a.a[i][y] = 0;
                        while (t2 % prime[y] == 0){
                            t2 *= inv_prime[y];
                            a.a[i][y]++;
                        }
                    }
                    if (t2 == 1){
                        a.a[i][cnt_i] = (lll)j * x % phi_p;
                        break;
                    } else {
                        i--;
                    }
                }
            }
        }
    } while (!gauss(a, phi_p));
    for (register int i = 1; i <= cnt; i++){
        ind[i] = a.a[i][cnt_i];
    }
}

inline bool is_prime(ll n){
    if (n < 2) return false;
    int cnt = 0;
    ll m = n - 1, k = m;
    while (!(k & 1)){
        k >>= 1;
        cnt++;
    }
    for (register int i = 1; i <= K; i++){
        if (n == test_prime[i]) return true;
        ll a = quick_pow(test_prime[i], k, n), b = a;
        for (register int j = 1; j <= cnt; j++){
            b = (lll)b * a % n;
            if (b == 1 && a != 1 && a != m) return false;
            a = b;
        }
        if (a != 1) return false;
    }
    return true;
}

inline ll floyd(ll a, ll b, ll p){
    return ((lll)a * a % p + b) % p;
}

inline ll abs64(ll n){
    return n >= 0 ? n : -n;
}

inline ll pollard_pho(ll n){
    ll x = 0, c = rand() % n;
    for (register int i = 1; ; i <<= 1){
        ll y = 1, z = x;
        for (register int j = 1; j <= i; j++){
            x = floyd(x, c, n);
            y = (lll)y * abs64(x - z) % n;
            if (j == i || j % S == 0){
                ll ans = gcd(n, y);
                if (ans > 1) return ans;
            }
        }
    }
}

void decompound(ll n){
    if (n < 2) return;
    if (is_prime(n)){
        mp[n]++;
        return;
    }
    ll factor;
    do {
        factor = pollard_pho(n);
    } while (factor == n);
    decompound(factor);
    while (n % factor == 0){
        n /= factor;
    }
    decompound(n);
}

inline ll get_least_primitive_root(ll n){
    ll phi_n = n - 1;
    decompound(phi_n);
    for (register ll i = 0; i < n; i++){
        if (gcd(i, n) != 1) continue;
        bool flag = true;
        for (register map<ll, int>::iterator j = mp.begin(); j != mp.end(); j++){
            if (quick_pow(i, phi_n / j->first, n) == 1){
                flag = false;
                break;
            }
        }
        if (flag) return i;
    }
    return -1;
}

inline ll index_calculus(ll a, ll g, ll p, int cnt){
    a %= p;
    if (a == 1) return 0;
    if (p == 2) return -1;
    ll phi_p = p - 1;
    for (register ll i = rand64() % phi_p, j = quick_pow(g, i, p); ; i = rand64() % phi_p, j = quick_pow(g, i, p)){
        for (register ll k = j, l = 1; l < phi_p; k = (lll)k * j % p, l++){
            int t1;
            ull t2 = (lll)a * k % p, t3;
            t1 = __builtin_ctzll(t2);
            t2 >>= t1;
            t3 = t2;
            for (register int x = 2; x <= cnt && t2 > 1; x++){
                if ((x == 10 && t2 > 1e15) || (x == 30 && t2 > 1e12)) break;
                while (t2 * inv_prime[x] < t2) t2 *= inv_prime[x];
            }
            if (t2 == 1){
                ll ans = (((lll)t1 * ind[1] % phi_p - (lll)i * l % phi_p) % phi_p + phi_p) % phi_p;
                for (register int x = 2; x <= cnt && t3 > 1; x++){
                    while (t3 % prime[x] == 0){
                        t3 *= inv_prime[x];
                        ans = (ans + ind[x]) % phi_p;
                    }
                }
                if (t3 == 1) return ans;
            }
        }
    }
}

inline ll dlog(ll a, ll b, ll g, ll p, int cnt){
    a %= p;
    b %= p;
    if (b == 1) return 0;
    ll x = index_calculus(a, g, p, cnt), phi_p = p - 1, d = gcd(x, phi_p), y = index_calculus(b, g, p, cnt);
    if (y % d != 0) return -1;
    phi_p /= d;
    return ((lll)(y / d) * inv(x / d, phi_p) % phi_p + phi_p) % phi_p;
}

int main(){
    int t, cnt;
    ll p, g;
    cin >> t >> p;
    cnt = init1(p);
    srand(time(NULL));
    g = get_least_primitive_root(p);
    init2(g, p, cnt);
    for (register int i = 1; i <= t; i++){
        ll a, b;
        cin >> a >> b;
        cout << dlog(a, b, g, p, cnt) << endl;
    }
    return 0;
}