GDOI 2021 PJ D2T2
你管这叫普及?
设
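主要转移是考虑最大的数 $n$ 的最高位 $w = \lfloor \log_2 n \rfloor$:令 $L = 2^w - 1$,$R = n - 2^w$,再由 $L$、$R$ 两部分的结果合并出 $n$ 的结果(具体见参考实现)。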
考虑一些特殊情况:只有两个数一个最高位
考虑可能没有逆元(例如模数是偶数时 $2$ 没有逆元),一种简单的方式是额外维护一个辅助函数 $g$,直接递推原本需要做除法才能得到的量,整个转移中不出现除法(见参考实现)。
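时间复杂度 $O(\log^4 n)$:递归中出现的不同的 $n$ 只有 $O(\log n)$ 个,$l, r$ 各有 $O(\log n)$ 种取值,每次转移枚举 $O(\log n)$ 个 $d$(忽略 `std::map` 的查询开销)。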
参考实现:
#include <bits/stdc++.h>
// 64-bit signed integer shorthand used by every function in this file.
typedef long long ll;

// First dimension of the memo tables: one slot per distinct id handed out
// through `mp` / `tot`.
const int N = 125;

// Forward declarations; the definitions come after main().
void inc(int &a, int b);
void dec(int &a, int b);
int f(ll n, int l, int r);
int g(ll n, int l, int r);

// Returns the index of the highest set bit of n, i.e. 63 minus the number of
// leading zero bits. n must be non-zero: __builtin_clzll(0) is undefined.
int lg2(ll n) {
    const int leading_zeros = __builtin_clzll(n);
    return 63 - leading_zeros;
}

// Memo tables for f and g, indexed as [id of n][l][r]. main() fills the used
// slots with -1, which f and g treat as "not computed yet".
int _f[N][60][60];
int _g[N][60][60];

// Maps an argument n of f / g to its compact id (first index of the tables).
std::map<ll, int> mp;

// Most recently assigned id; f and g pre-increment it, so ids start at 1.
int tot;

// pw2[i] holds 2^i reduced modulo Mod (filled in main()).
int pw2[120];

// Inputs read by main(): the bound n and the modulus Mod.
ll n;
int Mod;
// Entry point.
//
// Reads `n` and `Mod` from stdin, fills the power-of-two table and the memo
// sentinels, then prints
//     ((n + 1) + sum over 0 <= i, j < 60 of f(n, i, j)) mod Mod
// as a value in [0, Mod).
//
// Returns 0 on success and 1 when the input is malformed or Mod is not
// positive (every later `% Mod` would otherwise be undefined).
int main() {
    if (scanf("%lld %d", &n, &Mod) != 2 || Mod <= 0) return 1;

    // pw2[i] = 2^i mod Mod. The doubling is done in 64 bits so it cannot
    // overflow even when Mod is close to INT_MAX, and pw2[0] is reduced too
    // so that Mod == 1 yields an all-zero table.
    pw2[0] = 1 % Mod;
    for (int i = 1; i < 120; ++i)
        pw2[i] = static_cast<int>(2ll * pw2[i - 1] % Mod);

    // All bytes 0xFF makes every int entry -1, the "not computed" marker that
    // f and g test for. Every slot is marked, not only the first 120.
    std::memset(_f, -1, sizeof _f);
    std::memset(_g, -1, sizeof _g);

    int ans = static_cast<int>((n + 1) % Mod);
    for (int i = 0; i < 60; ++i) {
        for (int j = 0; j < 60; ++j) {
            ans = static_cast<int>((1ll * ans + f(n, i, j)) % Mod);
        }
    }
    // `%` keeps the sign of the dividend, so fold a negative residue back
    // into [0, Mod) before printing.
    if (ans < 0) ans += Mod;
    printf("%d\n", ans);
    return 0;
}
// Adds b to a modulo Mod, in place.
//
// Expects a in [0, Mod) and b in (-Mod, Mod); leaves a in [0, Mod).
// b is allowed to be negative because `%` keeps the sign of its left operand:
// a product whose last factor was negative arrives here as a negative
// residue, and without the second branch a would stay negative.
// NOTE(review): a + b is still computed in int, so Mod must stay below 2^30.
void inc(int &a, int b) {
    a += b;
    if (a >= Mod) {
        a -= Mod;
    } else if (a < 0) {
        a += Mod;
    }
}
// Subtracts b from a in place; if the difference drops below zero, Mod is
// added back once.
void dec(int &a, int b) {
    a -= b;
    if (a < 0) {
        a += Mod;
    }
}
// Memoised evaluation of the `f` recurrence for bound n and bit indices l, r.
//
// Mechanics (the combinatorial meaning of f is given by the write-up that
// accompanies this code, not restated here):
//   * returns 0 when n == 0 or when 2^max(l, r) > n;
//   * otherwise splits n at its highest set bit w into L = 2^w - 1 and
//     R = n - 2^w and combines f / g values of the two parts;
//   * results are cached in _f[id][l][r], where id is the compact id that
//     `mp` assigns to n; -1 marks an entry that has not been computed.
//
// The returned value is always in [0, Mod).
int f(ll n, int l, int r) {
    if (n == 0) return 0;
    if ((1ll << std::max(l, r)) > n) return 0;
    if (!mp.count(n)) mp[n] = ++tot;
    int now = mp[n];
    if (_f[now][l][r] != -1) return _f[now][l][r];

    int w = lg2(n);
    ll L = (1ll << w) - 1, R = n - (1ll << w);
    int cnt = static_cast<int>((R + 1) % Mod);  // (R + 1) mod Mod, used twice below

    int ans = (f(L, l, r) + f(R, l, r)) % Mod;

    // s1 = sum over d < w of f(L, l, d), s2 = sum over d < w of f(R, d, r).
    int s1 = 0, s2 = 0;
    for (int d = 0; d < w; ++d) {
        inc(s1, f(L, l, d));
        inc(s2, f(R, d, r));
    }

    if (l == w && r == w) inc(ans, static_cast<int>(1ll * pw2[w] * cnt % Mod));
    // sum over d < w of pw2[w] * f(R, d, r) equals pw2[w] * s2 modulo Mod.
    if (l == w) inc(ans, static_cast<int>(1ll * pw2[w] * s2 % Mod));
    // sum over d < w of f(L, l, d) * (R + 1) equals s1 * (R + 1) modulo Mod.
    if (r == w) inc(ans, static_cast<int>(1ll * s1 * cnt % Mod));

    inc(ans, static_cast<int>(1ll * s1 * s2 % Mod));

    for (int d = 0; d < w; ++d) {
        int fr = f(R, d, r);
        dec(ans, static_cast<int>(1ll * f(L, l, d) * fr % Mod));
        // 2^d - 1 modulo Mod. pw2[d] is 0 whenever Mod divides 2^d, so the
        // plain difference would be -1 and turn the whole product negative.
        int coef = pw2[d] - 1;
        if (coef < 0) coef += Mod;
        inc(ans, static_cast<int>(1ll * g(L, l, d) * fr % Mod * coef % Mod));
    }

    // Defensive normalisation: keeps the cached value inside [0, Mod), so it
    // can never collide with the -1 "not computed" marker.
    ans %= Mod;
    if (ans < 0) ans += Mod;
    return _f[now][l][r] = ans;
}
// Memoised evaluation of the auxiliary `g` recurrence for bound n and bit
// indices l, r.
//
// Mechanics (the meaning of g is given by the write-up that accompanies this
// code, not restated here):
//   * returns 0 when n == 0 or when 2^max(l, r) > n;
//   * otherwise splits n at its highest set bit w into L = 2^w - 1 and
//     R = n - 2^w and combines f / g values of the two parts;
//   * results are cached in _g[id][l][r], where id is the compact id that
//     `mp` assigns to n; -1 marks an entry that has not been computed.
//
// The returned value is always in [0, Mod).
int g(ll n, int l, int r) {
    if (n == 0) return 0;
    if ((1ll << std::max(l, r)) > n) return 0;
    if (!mp.count(n)) mp[n] = ++tot;
    int now = mp[n];
    if (_g[now][l][r] != -1) return _g[now][l][r];

    int w = lg2(n);
    ll L = (1ll << w) - 1, R = n - (1ll << w);
    // R never exceeds 2^w - 1, so (R + 1) >> w is 1 exactly when
    // R + 1 == 2^w and 0 otherwise.
    int top = static_cast<int>(((R + 1) >> w) % Mod);

    int ans = (g(L, l, r) + g(R, l, r)) % Mod;

    // s1 = sum over d < w of f(L, l, d), s2 = sum over d < w of g(R, d, r).
    int s1 = 0, s2 = 0;
    for (int d = 0; d < w; ++d) {
        inc(s1, f(L, l, d));
        inc(s2, g(R, d, r));
    }

    if (l == w && r == w) inc(ans, pw2[w]);
    // sum over d < w of pw2[w] * g(R, d, r) equals pw2[w] * s2 modulo Mod.
    if (l == w) inc(ans, static_cast<int>(1ll * pw2[w] * s2 % Mod));
    // sum over d < w of f(L, l, d) * top equals s1 * top modulo Mod.
    if (r == w) inc(ans, static_cast<int>(1ll * s1 * top % Mod));

    inc(ans, static_cast<int>(1ll * s1 * s2 % Mod));

    for (int d = 0; d < w; ++d) {
        int gr = g(R, d, r);
        dec(ans, static_cast<int>(1ll * f(L, l, d) * gr % Mod));
        // 2^d - 1 modulo Mod. pw2[d] is 0 whenever Mod divides 2^d, so the
        // plain difference would be -1 and turn the whole product negative.
        int coef = pw2[d] - 1;
        if (coef < 0) coef += Mod;
        inc(ans, static_cast<int>(1ll * g(L, l, d) * gr % Mod * coef % Mod));
    }

    // Defensive normalisation: keeps the cached value inside [0, Mod), so it
    // can never collide with the -1 "not computed" marker.
    ans %= Mod;
    if (ans < 0) ans += Mod;
    return _g[now][l][r] = ans;
}