P6736 「Wdsr-2」白泽教育

· · 题解

思路:

依次考虑三个问题。

n = 1

即求 a^x \equiv b \pmod p,直接 BSGS 即可。

时间复杂度为 O(T \sqrt{p})

n = 2

显然可以得到 a \uparrow \uparrow b = a^{a^{a^{\cdots}}},总共有 ba

这是幂塔的形势,考虑扩展欧拉定理:

a^b \equiv \begin{cases} a^b & b \le \varphi(p) \\ a^{b \bmod \varphi(p) + \varphi(p)} & b > \varphi(p) \end{cases}

对于这样的幂塔,显然可以递归的去算,每次 p \to \varphi(p),显然最多 \log p 次,p \to 1,那么更上层的一定没有必要了;这里设要 h 次。

所以对于方程 a \uparrow \uparrow x \equiv b \pmod p 时,可以得到 x 的大小是 0 \sim h 的,暴力枚举 x,然后根据扩展欧拉定理计算幂塔即可。

一次算 \varphi(p)O(\sqrt p) 的,要算 \log p 次;算一次高度为 x 的幂塔时间复杂度是 x \log p,于是总复杂度应该是 O(T(\sqrt p \log p + \log^3 p))

n = 3

根据递推关系,有:

a \uparrow \uparrow \uparrow x = a \uparrow \uparrow (a\uparrow \uparrow\uparrow (x - 1))

(a\uparrow \uparrow\uparrow (x - 1))a 的幂塔,也就是是说 x 的上界应该是满足 (a\uparrow \uparrow\uparrow (b - 1)) > h 的最小的 b;显然这个 x 应该是特别小的,不超过 4,你暴力递推计算 (a\uparrow \uparrow\uparrow (b - 1)) 就行。

时间复杂度类似,也是 O(T\sqrt p \log p)

完整代码:

#include<bits/stdc++.h>
#define lowbit(x) x & (-x)
#define pi pair<ll, ll>
#define ls(k) k << 1
#define rs(k) k << 1 | 1
#define fi first
#define se second
using namespace std;
typedef __int128 __;
typedef long double lb;
typedef double db;
typedef unsigned long long ull;
typedef long long ll;
const int N = 31, inf = 2e9;
inline ll read(){
    ll x = 0, f = 1;
    char c = getchar();
    while(c < '0' || c > '9'){
        if(c == '-')
            f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9'){
        x = (x << 1) + (x << 3) + (c ^ 48);
        c = getchar();
    }
    return x * f;
}
inline void write(ll x){
    if(x < 0){
        putchar('-');
        x = -x;
    }
    if(x > 9)
        write(x / 10);
    putchar(x % 10 + '0');
}
int T, a, n, b, p, cnt;
int Vmod[N];
unordered_map<int, int> mp;
inline int solve(int a, int b, int p){
    if(b == 1 || p == 1)
        return 0;
    mp.clear();
    int m = sqrt(p) + 1, now = 1;
    for(int i = 0; i < m; ++i){
        mp[1ll * b * now % p] = i;
        now = 1ll * now * a % p;
    }
    int x = now;
    now = 1;
    for(int i = 1; i <= m; ++i){
        now = 1ll * now * x % p;
        if(mp.count(now))
            return i * m - mp[now];
    }
    return -1;
}
inline int getphi(int n){
    int mul = 1;
    for(int i = 2; i * i <= n; ++i){
        if(n % i == 0){
            n /= i;
            mul *= (i - 1);
            while(n % i == 0){
                n /= i;
                mul *= i;
            }
        }
    }
    if(n > 1)
      mul *= (n - 1);
    return mul;
}
struct Node{
    int data;
    ll w; // if w > inf, w = -1
};
inline Node mul(Node a, Node b, int p){
    Node ans;
    ans.data = 1ll * a.data * b.data % p;
    if(a.w == -1 || b.w == -1)
      ans.w = -1;
    else{
        ans.w = a.w * b.w;
        if(ans.w > inf)
          ans.w = -1;
    }
    return ans;
}
inline Node qpow(int a, int b, int p){
    Node ans = {1, 1ll};
    Node A = {a, a};
    while(b){
        if(b & 1)
          ans = mul(ans, A, p);
        A = mul(A, A, p);
        b >>= 1;
    }
    return ans;
}
inline Node power(int a, int dep, int mod){
    if(!dep)
      return {1 % Vmod[mod], 1ll};
    if(dep == 1)
      return {a % Vmod[mod], a};
    if(mod == cnt)
      return {0, -1};
    auto t = power(a, dep - 1, mod + 1);
    if(t.w != -1 && t.w < Vmod[mod + 1])
      return qpow(a, t.data, Vmod[mod]);
    else
      return qpow(a, t.data + Vmod[mod + 1], Vmod[mod]);
}
inline void solve(){
    a = read(), n = read(), b = read(), p = read();
    if(n == 1){
        write(solve(a, b, p));
        putchar('\n');
        return ;
    }
    if(p == 1){
        puts("0");
        return ;
    }
    cnt = 0;
    int x = p;
    Vmod[++cnt] = x;
    while(x != 1){
        x = getphi(x);
        Vmod[++cnt] = x;
    }
    if(n == 2){
        for(int x = 0; x <= cnt; ++x){
            if(power(a, x, 1).data == b){
                write(x);
                putchar('\n');
                return ;
            }
        }
        puts("-1");
        return ;
    }
    if(b == 1){
        puts("0");
        return ;
    }
    if(a == 1){
        puts("-1");
        return ;
    }
    if(a % p == b % p){
        puts("1");
        return ;
    }
    if(power(a, a, 1).data == b){
        puts("2");
        return ;
    }
    auto t = power(a, a, 1);
    int dep = 0;
    if(t.w != -1 && t.w <= cnt) dep = t.w;
    else dep = cnt;
    if((t = power(a, dep, 1)).data == b){
        puts("3");
        return ;
    }
    dep = 0;
    if(t.w != -1 && t.w <= cnt) dep = t.w;
    else dep = cnt; 
    if(power(a, dep, 1).data == b){
        puts("4");
        return ;
    }
    puts("-1");
}
int main(){
    T = read();
    while(T--)
      solve();
    return 0;
}