求卡常

P5205 【模板】多项式开根

```cpp #include <bits/stdc++.h> using namespace std; typedef long long ll; const int N = 270010, LOGN = 23; const ll mod = 998244353, g = 114514, invg = 137043501; int n, m, rgt[N]; vector <ll> w1[LOGN], w2[LOGN]; ll power (ll a, ll b) { return b == 0? 1: (b & 1? a: 1) * power (a * a % mod, b >> 1) % mod; } void init (int n) { int k = 0; for (; (1 << k) <= n; k++); for (int i = 0; i < (1 << k); i++) { rgt[i] = (rgt[i >> 1] >> 1) | ((i & 1) << (k - 1)); } for (int i = 1; i <= k; i++) { w1[i].resize (1 << i); w1[i][0] = 1, w1[i][1] = power (g, (mod - 1) >> i); for (int j = 2; j < (1 << i); j++) { w1[i][j] = (w1[i][j - 1] * w1[i][1]) % mod; } w2[i].resize (1 << i); w2[i][0] = 1, w2[i][1] = power (invg, (mod - 1) >> i); for (int j = 2; j < (1 << i); j++) { w2[i][j] = (w2[i][j - 1] * w2[i][1]) % mod; } } } void dft (vector <ll>& a, int inv) { int n = a.size (); assert ((n & (-n)) == n); for (int i = 0; i < n; i++) { if (i < rgt[i]) { swap (a[i], a[rgt[i]]); } } for (int mid = 1, st = 1; mid < n; mid <<= 1, st++) { for (int i = 0; i < n; i += mid << 1) { ll* w = &(inv == 1? 
w1: w2)[st][0]; for (int j = 0; j < mid; j++) { ll x = a[i + j], y = *(w++) * a[i + j + mid] % mod; a[i + j] = (x + y) % mod; a[i + j + mid] = (x - y + mod) % mod; } } } if (inv == -1) { ll ninv = power (n, mod - 2); for (int i = 0; i < n; i++) { (a[i] *= ninv) %= mod; } } } struct poly { vector <ll> a; poly () {} poly (int n_) { a.resize (n_ + 1); } poly& operator = (int n_) { a.resize (n_ + 1); return *this; } int size () { return a.size () - 1; } ll& operator [] (int p) { return a[p]; } void suit () { int n = a.size () - 1; int k = 0; for (; (1 << k) <= n; k++); a.resize (1 << k); } }; poly pcopy (poly pre, int n_) { poly ans = pre; return ans = n_; } poly operator + (poly f, poly g) { int n = max (f.size (), g.size ()); f = n; g = n; for (int i = 0; i <= n; i++) { (f[i] += g[i]) %= mod; } return f; } poly operator - (poly f, poly g) { int n = max (f.size (), g.size ()); f = n; g = n; for (int i = 0; i <= n; i++) { ((f[i] -= g[i]) += mod) %= mod; } return f; } poly operator * (poly f, ll c) { for (int i = 0; i <= f.size (); i++) { (f[i] *= c) %= mod; } return f; } poly operator * (ll c, poly f) { return f * c; } poly operator / (poly f, ll c) { return f * power (c, mod - 2); } poly operator * (poly f, poly g) { poly flong = pcopy (f, f.size () + g.size ()); poly glong = pcopy (g, f.size () + g.size ()); flong.suit (); glong.suit (); init (flong.size ()); dft (flong.a, 1); dft (glong.a, 1); for (int i = 0; i <= flong.size (); i++) { (flong[i] *= glong[i]) %= mod; } dft (flong.a, -1); return flong = f.size () + g.size (); } poly derivative (poly f) { int n = f.size (); poly ans = n - 1; for (int i = 0; i < n; i++) { ans[i] = f[i + 1] * (i + 1) % mod; } return ans; } poly intergal (poly f) { int n = f.size (); poly ans = n + 1; for (int i = n + 1; i >= 1; i--) { ans[i] = f[i - 1] * power (i, mod - 2) % mod; } return ans; } poly inv (poly f) { int tmpn = f.size (); f.suit (); int limn = f.size (); poly ans = 0; ans[0] = power (f[0], mod - 2); for (int pw = 1; (1 << 
pw) - 1 <= limn; pw++) { poly f0 = ans; int n = (1 << pw) - 1, m = f0.size (); f0 = n + m; poly fl = pcopy (f, n + m); f0.suit (); fl.suit (); init (f0.size ()); dft (f0.a, 1); dft (fl.a, 1); for (int i = 0; i <= f0.size (); i++) { (f0[i] *= f0[i] * fl[i] % mod) %= mod; } dft (f0.a, -1); ans = n; for (int i = (n >> 1) + 1; i <= n; i++) { ans[i] = (mod - f0[i]) % mod; } } return ans = tmpn; } poly sqrt (poly f) { int tmpn = f.size (); f.suit (); int limn = f.size (); poly ans = 0; ans[0] = 1; for (int pw = 1; (1 << pw) - 1 <= limn; pw++) { poly f0 = ans; int n = (1 << pw) - 1; f0 = n; poly curr = (f0 + f * inv (f0)) / 2; ans = pcopy (curr, n); } return ans = tmpn; } poly ln (poly f) { poly ans = intergal (derivative (f) * inv (f)); return ans = f.size (); } poly a; int main () { scanf ("%d", &n); a = --n; for (int i = 0; i <= n; i++) { scanf ("%lld", &a[i]); } a = sqrt (a); for (int i = 0; i <= n; i++) { printf ("%lld%c", a[i], " \n"[i == n]); } return 0; } ```
by denominator @ 2024-01-07 16:25:30


@[denominator](/user/174009) 取模偏多了,比如 `dft` 里面这句 `y = *(w++) * a[i + j + mid] % mod;` 可以把 `% mod` 去掉,但是下面的 `a[i + j + mid] = (x - y + mod) % mod;` 要跟着改为 `a[i + j + mid] = rds((x-y)%mod);`。这样做是安全的:去掉取模后 `y < mod²`,仍在 `long long` 范围内;而 `(x-y)%mod` 落在 `(-mod, mod)` 中,可以放心传给 `int` 参数。`rds` 长这样:

```cpp
inline int rds(int x){ return x < 0 ? x + mod : x;}
```

这样一改最慢点就是 650ms 左右了,虽然离目标还有段距离。如果想继续卡的话,可能要把数组尽可能开成 `int`,这样会快不少。
by 飞雨烟雁 @ 2024-01-07 17:10:26


@[飞雨烟雁](/user/375984) thx,刚刚加了一堆无关紧要的优化(包括 `i++` 改为 `++i`,快读,把你的 `rds` 搞成宏),再减少取模次数,可以成功卡进 600ms 大关,但是不吸氧依然 T 飞,而且 1.20s,可能因为厌氧。
by denominator @ 2024-01-07 17:20:28


开 `long long` 和多取模我已经习惯了,要改改(((
by denominator @ 2024-01-07 17:21:10


@[denominator](/user/174009) 厌氧->UB,写错了。
by denominator @ 2024-01-07 17:21:43


|