[LOJ 6485 LJJ 学二项式定理] 题解/单位根反演小记

· · 题解

单位根反演

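[a \equiv b \pmod n] = [a-b \equiv 0 \pmod n] = \frac{1}{n} \sum \limits_{k=0}^{n-1} \omega_{n}^{(a-b)k} = \frac{1}{n} \sum \limits_{k=0}^{n-1} \omega_{n}^{ak} \omega_{n}^{-bk}

其中 \omega_n 为 n 次本原单位根。该式成立的原因是:由等比数列求和,当 n \mid t 时 \sum \limits_{k=0}^{n-1} \omega_n^{tk} = n,否则该和为 \frac{\omega_n^{tn} - 1}{\omega_n^{t} - 1} = 0。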

LOJ 6485 LJJ 学二项式定理

给定 n, s, a_0, a_1, a_2, a_3,求

\sum \limits_{i = 0}^n \left( \binom{n}{i} \times s^i \times a_{i \bmod 4} \right)

对 998244353 取模的值。

有:

\begin{aligned} & \sum \limits_{i = 0}^n \left( \binom{n}{i} \times s^i \times a_{i \bmod 4} \right)\\ = & \sum \limits_{i = 0}^n \left( \binom{n}{i} \times s^i \sum \limits_{j=0}^3 a_j [i \equiv j \pmod 4] \right)\\ \end{aligned}

由单位根反演可知:

[i \equiv j \pmod 4] = \frac{1}{4} \sum \limits_{k=0}^3 \omega_{4}^{ik} \omega_{4}^{-jk}

所以:

\begin{aligned} & \sum \limits_{i = 0}^n \left( \binom{n}{i} \times s^i \sum \limits_{j=0}^3 a_j [i \equiv j \pmod 4] \right)\\ = & \sum \limits_{i = 0}^n \left( \binom{n}{i} \times s^i \sum \limits_{j=0}^3 a_j \times \frac{1}{4} \sum \limits_{k=0}^3 \omega_{4}^{ik} \omega_{4}^{-jk}\right)\\ = & \frac{1}{4} \sum \limits_{k=0}^3 \sum \limits_{j=0}^3 a_j \omega_4^{-jk} \sum \limits_{i=0}^n \binom{n}{i} s^i \omega_4^{ik}\\ = & \frac{1}{4} \sum \limits_{k=0}^3 \sum \limits_{j=0}^3 a_j \omega_4^{-jk} \sum \limits_{i=0}^n \binom{n}{i} (s \omega_4^k)^i \\ = & \frac{1}{4} \sum \limits_{k=0}^3 \sum \limits_{j=0}^3 a_j \omega_4^{-jk} (s \omega_4^k + 1)^n\\ \end{aligned}

(注:在模质数 p 意义下,若 g 为 p 的原根且 n \mid p-1,则可取 \omega_n = g^{(p-1)/n}。)

由于 998244353 的原根是 3,且 4 \mid 998244353-1,所以 \omega_4 = 3^{(998244353-1)/4}。

然后就可以计算了。

#include <iostream>
#include <cstdio>
using namespace std;
typedef long long ll;

// Prime modulus of the problem; main() relies on 3 being a primitive root of it.
const ll mod = 998244353;

// Query state shared with main(): T = number of test cases, w = primitive 4th
// root of unity (assigned in main), s / a[0..3] / n = current query, ans = result.
int T, w, s, a[4];
ll n, ans;

// Conditional subtraction: maps x in [0, 2 * mod) to x mod `mod`.
// Intended for the sum of two already-reduced residues, where it replaces a `%`.
inline ll chk(ll x) {
    return x >= mod ? x - mod : x;
}

// Binary (square-and-multiply) modular exponentiation: returns a^k mod `mod`.
// With the default k = mod - 2 it returns the modular inverse of a (Fermat's
// little theorem; requires a not divisible by mod).
// The exponent k is expected to be non-negative; k <= 0 yields 1.
// The base is reduced into [0, mod) first. Without that, a base >= ~3.04e9 would
// overflow `a * a` in signed 64-bit arithmetic (undefined behaviour), and a
// negative base would produce a negative result.
inline ll fpm(ll a, ll k = mod - 2) {
    a %= mod;
    if (a < 0)
        a += mod;

    ll res = 1;

    for (; k > 0; k >>= 1) {
        if (k & 1)
            res = res * a % mod;

        a = a * a % mod;
    }

    return res;
}

int main() {
    w = fpm(3, (mod - 1) / 4);

    scanf("%d", &T);

    while (T--) {
        scanf("%lld %d %d %d %d %d", &n, &s, &a[0], &a[1], &a[2], &a[3]);

        ans = 0;

        for (int k = 0; k < 4; k++)
            for (int j = 0; j < 4; j++)
                ans = chk(ans + 1ll * a[j] * fpm(fpm(w), j * k) % mod * fpm(1ll * s * fpm(w, k) % mod + 1, n) % mod);

        printf("%lld\n", 1ll * ans * fpm(4) % mod);
    }

    return 0;
}