[AGC038C] LCMs 题解

· · 题解

一道比较简单的莫反练习题。

题意

给定长度为 n 的数列 a_1, a_2, a_3, ..., a_n,求:

\left( \sum_{i=1}^n \sum_{j=i+1}^n {\rm lcm}(a_i, a_j) \right) \bmod 998244353

题解

首先,考虑把问题转化一下。设 sum = \sum_{i=1}^n \sum_{j=1}^n {\rm lcm}(a_i, a_j),那么有

\begin{aligned} \sum_{i=1}^n \sum_{j=i+1}^n {\rm lcm}(a_i, a_j) = \frac{sum - \sum_{i=1}^n a_i}{2} \end{aligned}

所以考虑如何求出 sum。发现直接处理 a_i 并不好做,但是观察到 1 \leq a_i \leq 10^6,考虑给每个数字开一个桶然后计数。不妨用 c_i 表示数字 i 出现的次数,用 m 表示最大值,那么有:

sum = \sum_{i=1}^m \sum_{j=1}^m {\rm lcm}(i, j) \times c_i \times c_j

利用 {\rm lcm}(i, j) = \frac{ij}{\gcd(i, j)} 改写,为后面的莫比乌斯反演做准备:

\begin{aligned} sum &= \sum_{i=1}^m \sum_{j=1}^m \frac{ij}{\gcd(i, j)} \times c_i \times c_j \end{aligned}

枚举 d = \gcd(i, j),把原来的 i, j 写成 id, jd(此时新的 i, j 互质),再利用 [\gcd(i, j) = 1] = \sum_{g \mid \gcd(i, j)} \mu(g) 展开:

\begin{aligned} & \sum_{d=1}^m \sum_{i=1}^{\lfloor \frac{m}{d} \rfloor} \sum_{j=1}^{\lfloor \frac{m}{d} \rfloor} \frac{ijd^2}{d} \times c_{id} \times c_{jd} \ [\gcd(i, j) = 1]\\ =& \sum_{d=1}^m \sum_{i=1}^{\lfloor \frac{m}{d} \rfloor} \sum_{j=1}^{\lfloor \frac{m}{d} \rfloor} ijd \times c_{id} \times c_{jd} \ [\gcd(i, j) = 1]\\ =& \sum_{d=1}^m \sum_{g=1}^{\lfloor \frac{m}{d} \rfloor} \mu(g) \sum_{i=1}^{\lfloor \frac{m}{dg} \rfloor} \sum_{j=1}^{\lfloor \frac{m}{dg} \rfloor} ijdg^2 \times c_{idg} \times c_{jdg} \end{aligned}

令 T = dg,改为先枚举 T,再枚举 T 的因数 d(此时 g = T/d):

\begin{aligned} & \sum_{T=1}^m \sum_{d | T} \mu \left( \frac{T}{d} \right) \sum_{i=1}^{\lfloor \frac{m}{T} \rfloor} \sum_{j=1}^{\lfloor \frac{m}{T} \rfloor} \frac{ijT^2}{d} \times c_{iT} \times c_{jT}\\ =& \sum_{T=1}^m \sum_{d | T} \mu \left( \frac{T}{d} \right) \frac{T^2}{d} \sum_{i=1}^{\lfloor \frac{m}{T} \rfloor} \sum_{j=1}^{\lfloor \frac{m}{T} \rfloor} i \times j \times c_{iT} \times c_{jT}\\ =& \sum_{T=1}^m \sum_{d | T} \mu \left( \frac{T}{d} \right) \frac{T^2}{d} \left( \sum_{i=1}^{\lfloor \frac{m}{T} \rfloor} i \times c_{iT} \right)^2 \end{aligned}

令 f(T) = \sum_{d | T} \mu \left( \frac{T}{d} \right) \frac{T^2}{d},\ g(T) = \sum_{i=1}^{\lfloor \frac{m}{T} \rfloor} i \times c_{iT}(这里的 g 是新定义的函数,与上式中的求和变量 g 无关)。两者都可以通过枚举倍数(调和级数)在 \mathcal{O}(m \ln m) 的时间内求出,那么最后:

sum = \sum_{T=1}^m f(T) g^2(T)

CODE(注意:代码中的变量名与上文相反——n 表示最大值(值域上界),m 表示数列长度):

#include <iostream>
#include <cstdio>
using namespace std;
typedef long long ll;

// Array bound: input values a_i lie in [1, 1e6]; +5 is slack.
const int maxn = 1e6 + 5;
const ll mod = 998244353;

// NOTE: naming is swapped relative to the write-up above:
//   n   = largest input value seen (bound for the sieve and the buckets),
//   m   = number of input values,
//   tot = number of primes found so far by the linear sieve.
int n, m, tot;
// prime[1..tot] = primes found by the sieve; mu[i] = Mobius function of i.
int prime[maxn>>1], mu[maxn];
// c[v] = how many times value v occurs in the input;
// f[], g[] = per-T tables filled by prework() and read by main().
ll c[maxn], f[maxn], g[maxn];
// not_prime[i] = true once the sieve has marked i composite.
bool not_prime[maxn];

// Modular exponentiation by repeated squaring: returns a^k mod `mod`.
// Callers should pass `a` already reduced modulo `mod`; otherwise a*a may overflow.
// Not meant for negative k: the loop relies on right shifts driving k to 0.
ll fpm(ll a, ll k) {
    ll acc = 1;
    for(ll base = a; k != 0; k >>= 1, base = base * base % mod) {
        if(k & 1)
            acc = acc * base % mod;
    }
    return acc;
}
// Modular inverse of 2. mod is odd, so 2 * ((mod + 1) / 2) == mod + 1 == 1 (mod mod).
// This is a compile-time constant rather than a call to fpm during static initialization.
constexpr ll inv2 = (mod + 1) / 2;

void prework() {
    mu[1] = 1;
    for(int i = 2; i <= n; i++) {
        if(!not_prime[i]) prime[++tot] = i, mu[i] = -1;
        for(int j = 1; j <= tot && i*prime[j] <= n; j++) {
            not_prime[i*prime[j]] = true;
            if(i%prime[j] == 0) {
                mu[i*prime[j]] = 0;
                break;
            }
            mu[i*prime[j]] = -mu[i];
        }
    }
    for(int i = 1; i <= n; i++)
        for(int j = 1; i*j <= n; j++)
            f[i*j] = (f[i*j] + (1ll * mu[j] * i * j % mod * j % mod + mod) % mod) % mod;
    for(int i = 1; i <= n; i++)
        f[i] = (f[i] + f[i-1]) % mod;
    for(int T = 1; T <= n; T++)
        for(int i = 1; i <= n/T; i++)
            g[T] = (g[T] + 1ll * i * c[i*T] % mod) % mod;
}

int main() {
    scanf("%d", &m);
    ll res = 0;
    for(int i = 1, x; i <= m; i++) {
        scanf("%d", &x);
        c[x]++;
        res = (res - x%mod + mod) % mod; 
        n = max(n, x);
    }
    prework();

    for(int i = 1; i <= n; i++) {
        ll x = (f[i] - f[i-1] + mod) % mod;
        res = (res + x * g[i] % mod * g[i] % mod) % mod;
    }

    printf("%lld\n", res * inv2 % mod);
    return 0;
}