多项式模板库

个人记录

常数因子似乎很大……

namespace MyPoly {

// NOTE(review): like the original, this relies on `vector`, `max` and `swap` being visible
// unqualified (e.g. `using namespace std;` earlier in the including file) — confirm.

// NTT-friendly prime 998244353 = 119 * 2^23 + 1, the generator g = 3 used to build the
// roots of unity, and ginv = 3^{-1} (mod p), since 3 * ((p + 1) / 3) = p + 1.
constexpr int mod = 998244353, g = 3, ginv = (mod + 1) / 3;

// Dense polynomial with coefficients in Z / mod.
// f[i] is the coefficient of x^i. f always holds at least one element, so the zero
// polynomial is stored as {0} and size() == f.size() - 1 is the index of the last
// stored coefficient. Arithmetic operators shrink() their result (drop trailing zeros).
struct poly {

vector <int> f;

poly() {
    f.resize(1, 0);
}

// x^p mod `mod` by binary exponentiation; expects 0 <= x < mod and p >= 0.
int qpow(int x, int p) const {
    int ans = 1;
    for (; p; p /= 2, x = 1ll * x * x % mod) {
        if (p & 1) ans = 1ll * ans * x % mod;
    }
    return ans;
}

// Index of the last stored coefficient (the degree once shrink() has run).
int size() const {
    return static_cast<int>(f.size()) - 1;
}

// Coefficient of x^n; 0 for any n outside the stored range, including n < 0.
int at(int n) const {
    return (n >= 0 && n <= size()) ? f[n] : 0;
}

// Sets the coefficient of x^n to t, growing storage with zeros if needed. Requires n >= 0.
void set(int n, int t) {
    if (n > size()) f.resize(n + 1, 0);
    f[n] = t;
}

// Ensures storage for coefficients 0..n (pads with zeros); never truncates.
void resize(int n) {
    if (n > size()) f.resize(n + 1, 0);
}

// Drops trailing zero coefficients, always keeping the constant term.
void shrink() {
    while (size() >= 1 && !f[ size() ]) f.pop_back();
}

// Modular addition of two values already reduced into [0, mod).
int add(int x, int y) const {
    int t = x + y;
    return t >= mod ? t - mod : t;
}

// Modular subtraction of two values already reduced into [0, mod).
int sub(int x, int y) const {
    int t = x - y;
    return t < 0 ? t + mod : t;
}

// Coefficient-wise sum.
friend poly operator + (const poly &a, const poly &b) {
    int n = max(a.size(), b.size());
    poly c;
    c.resize(n);
    for (int i = 0; i <= n; ++i) c.f[i] = c.add(a.at(i), b.at(i));
    c.shrink();
    return c;
}

// Coefficient-wise difference.
friend poly operator - (const poly &a, const poly &b) {
    int n = max(a.size(), b.size());
    poly c;
    c.resize(n);
    for (int i = 0; i <= n; ++i) c.f[i] = c.sub(a.at(i), b.at(i));
    c.shrink();
    return c;
}

// a mod x^n: keeps only the coefficients of x^0 .. x^{n-1}.
// n <= 0 yields the zero polynomial (previously a negative n indexed out of bounds).
friend poly operator % (poly a, int n) {
    if (n <= 0) return poly();
    if (n <= a.size()) a.f.resize(n);
    a.shrink();
    return a;
}

// In-place iterative radix-2 NTT of the first n = rev.size() coefficients of `a`.
// n must be a power of two dividing mod - 1 (so n <= 2^23), and rev must be the matching
// bit-reversal permutation. inv == 1 runs the forward transform; inv == -1 runs the inverse
// transform, including the final division by n. Does not read or modify *this.
void ntt(poly &a, int inv, vector <int> &rev) const {
    int n = static_cast<int>(rev.size());
    // Index the coefficient vector directly; at()/set() on every butterfly was a large constant factor.
    if (static_cast<int>(a.f.size()) < n) a.f.resize(n, 0);
    vector <int> &v = a.f;
    for (int i = 0; i < n; ++i) if (i < rev[i]) swap(v[i], v[ rev[i] ]);
    for (int half = 1; half < n; half <<= 1) {
        // Primitive (2 * half)-th root of unity (or its inverse for the inverse transform).
        int step = qpow(inv == 1 ? g : ginv, (mod - 1) / (half << 1));
        for (int i = 0; i < n; i += half << 1) {
            for (int j = 0, w = 1; j < half; ++j, w = 1ll * w * step % mod) {
                int u = v[i + j], t = 1ll * w * v[i + j + half] % mod;
                v[i + j] = add(u, t);
                v[i + j + half] = sub(u, t);
            }
        }
    }
    if (inv == -1) {
        int r = qpow(n, mod - 2);
        for (int i = 0; i < n; ++i) v[i] = 1ll * v[i] * r % mod;
    }
}

// Polynomial product via NTT. The result degree is a.size() + b.size(), which must stay below 2^23.
friend poly operator * (poly a, poly b) {
    int n = a.size(), m = b.size(), l = 1, lg = 0;
    // The transform length must strictly exceed the product degree n + m.
    while (l <= n + m) l <<= 1, ++lg;
    a.resize(l), b.resize(l);
    vector <int> rev(l, 0);
    for (int i = 1; i < l; ++i) {
        rev[i] = (rev[ i >> 1 ] >> 1) | ((i & 1) << (lg - 1));
    }
    poly c;
    c.ntt(a, 1, rev), c.ntt(b, 1, rev);
    c.f.assign(l, 0);
    for (int i = 0; i < l; ++i) c.f[i] = 1ll * a.f[i] * b.f[i] % mod;
    c.ntt(c, -1, rev);
    c.shrink();
    return c;
}

// Scalar product. m is first reduced into [0, mod), so negative scalars are handled too.
friend poly operator * (poly a, int m) {
    m %= mod;
    if (m < 0) m += mod;
    for (int &x : a.f) x = 1ll * x * m % mod;
    a.shrink();
    return a;
}

// Compound assignments; each returns `a` so that calls can be chained.
friend poly &operator += (poly &a, const poly &b) {
    return a = a + b;
}

friend poly &operator -= (poly &a, const poly &b) {
    return a = a - b;
}

friend poly &operator *= (poly &a, const poly &b) {
    return a = a * b;
}

friend poly &operator *= (poly &a, int m) {
    return a = a * m;
}

friend poly &operator %= (poly &a, int n) {
    return a = a % n;
}

// Multiplicative inverse 1 / a mod x^{a.size() + 1}, computed by Newton iteration
// h <- 2h - a h^2 with the precision doubling each round. Requires a.at(0) != 0;
// otherwise no inverse exists and the result is meaningless.
poly Inv(const poly &a) const {
    poly h;
    h.set(0, qpow(a.at(0), mod - 2));
    int w = 1;
    while (w <= a.size()) {
        w *= 2;
        poly head = a % w;
        poly next = h * 2 - head * h * h;
        h = next % w;
    }
    return h % (a.size() + 1);
}

// Formal derivative: f(x) -> f'(x).
poly Der(const poly &a) const {
    poly b;
    for (int i = 0; i < a.size(); ++i) b.set(i, 1ll * (i + 1) * a.at(i + 1) % mod);
    b.shrink();
    return b;
}

// Formal antiderivative with zero constant term: f(x) -> ∫_0^x f(t) dt.
// The result keeps every term, up to degree a.size() + 1. The previous version dropped
// the highest one. Degree is assumed to stay below mod, so that 1..degree are invertible.
poly Int(const poly &a) const {
    int n = a.size() + 1;
    // Linear-time table of modular inverses: inv[i] = -(mod / i) * inv[mod % i].
    vector <int> inv(n + 1, 1);
    for (int i = 2; i <= n; ++i) {
        inv[i] = sub(0, 1ll * inv[mod % i] * (mod / i) % mod);
    }
    poly b;
    b.resize(n);
    for (int i = 1; i <= n; ++i) b.f[i] = 1ll * inv[i] * a.at(i - 1) % mod;
    b.shrink();
    return b;
}

// ln(a) mod x^{a.size() + 1}, computed as ∫ a' / a. Requires a.at(0) == 1.
poly Ln(const poly &a) const {
    poly ratio = Der(a) * Inv(a);
    return Int(ratio) % (a.size() + 1);
}

// exp(a) mod x^{a.size() + 1}, computed by Newton iteration h <- h (1 - ln h + a).
// Requires a.at(0) == 0.
poly Exp(const poly &a) const {
    poly h;
    h.set(0, 1);
    poly one;
    one.set(0, 1);
    int w = 1;
    while (w <= a.size()) {
        w *= 2;
        poly head = a % w;
        // Pad h so that Ln(h) is evaluated to full precision w (Ln truncates to h.size() + 1).
        h.resize(w);
        h = h * (one - Ln(h) + head) % w;
    }
    return h % (a.size() + 1);
}

};

} using namespace MyPoly;