莫队

· · 个人记录

莫队基础

对于一段区间 [l,r],我们可以把它看为从 [l-1,r] - a_{l-1}[l,r + 1] - a_{r+1}[l+1,r] + a_l[l,r - 1] + a_r 得来的。

相当于两个指针定义范围,通过不断向外扩张和向内收缩得到区间答案。

而这样做 m 次查询的时间复杂度是很大的,考虑分块优化。

时间复杂度

先将数组分为长度为 \sqrt{n} 的块,按左端点所在块升序排序,若左端点在同一个块内,则按右端点升序排序。

  1. 左指针移动次数

    • 左端点跨块时,从一个块的末尾到下一个块的开头,最多移动 \sqrt{n} 次,总共有 \sqrt{n} 个块,因此跨块移动时间复杂度为 O((\sqrt{n})^2 = O(n)
    • 综上,左指针移动时间复杂度为:O(m\sqrt{n} + n)
  2. 右指针移动次数

    • 当左端点在同一个块内时,右端点是从左到右排序的,所以这个块的查询中,右端点的移动的时间复杂度为 O(n)
    • 总共有 \sqrt{n} 个块,因此右端点总移动次数为 O(n\sqrt{n})

综上,莫队的时间复杂度为 (n + m)\sqrt{n},当 n = m 时,O(n\sqrt{n}),一般题目 n \leq 2 \times 10^5

奇偶优化

在排序上做出改变,若两左端点都在一个奇数块时,右端点升序排序,反之,右端点降序排序。

P3901 数列找不同

维护一个 cnt_x 统计数字 x 当前出现的个数。

互不相同的条件即为数字个数等于当前询问的 r - l + 1

#include <bits/stdc++.h>

using namespace std;

const int N = 1e5 + 10;

int n, m, len, a[N], cnt[N], tot, ans[N];

struct Query {
    int l, r, id;
} q[N];

bool cmp(Query a, Query b) {
    if (a.l / len != b.l / len)
        return a.l < b.l;
    return a.r < b.r;
}

void SUB(int x) {
    (!--cnt[a[x]]) && (tot--);
}

void ADD(int x) {
    (++cnt[a[x]] == 1) && (tot++);
}

signed main() {
    cin >> n >> m;
    len = sqrt(n);
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
    } 
    for (int i = 1; i <= m; i++) {
        cin >> q[i].l >> q[i].r;
        q[i].id = i;
    }
    sort (q + 1, q + 1 + m, cmp);
    int l = 1, r = 0;
    for (int i = 1; i <= m; i++) {
        while (l < q[i].l)
            SUB(l++);
        while (l > q[i].l)
            ADD(--l);
        while (r < q[i].r)
            ADD(++r);
        while (r > q[i].r)
            SUB(r--);
        ans[q[i].id] = (tot == q[i].r - q[i].l + 1);
    }
    for (int i = 1; i <= m; i++) {
        cout << (ans[i] ? "Yes" : "No") << '\n';
    } 
    return 0;
} 

P2709 小B的询问

延续上一题,如果是 \operatorname{SUB} 函数就减掉要删的的加上后面的,反之就是先加再减。

#include <bits/stdc++.h>
#define int long long

using namespace std;

const int N = 5e4 + 10;

int n, m, len, a[N], cnt[N], tot, ans[N], k;

struct Query {
    int l, r, id;
} q[N];

bool cmp(Query a, Query b) {
    if (a.l / len != b.l / len)
        return a.l < b.l;
    return a.r < b.r;
}

void SUB(int x) {
    tot -= cnt[a[x]] * cnt[a[x]];
    cnt[a[x]]--;
    tot += cnt[a[x]] * cnt[a[x]];
  return;
}

void ADD(int x) {
    tot -= cnt[a[x]] * cnt[a[x]];
    cnt[a[x]]++;
    tot += cnt[a[x]] * cnt[a[x]];
  return;
}

signed main() {
    cin >> n >> m >> k;
    len = sqrt(n);
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
    } 
    for (int i = 1; i <= m; i++) {
        cin >> q[i].l >> q[i].r;
        q[i].id = i;
    }
    sort (q + 1, q + 1 + m, cmp);
    int l = 1, r = 0;
    for (int i = 1; i <= m; i++) {
        while (l < q[i].l)
            SUB(l++);
        while (l > q[i].l)
            ADD(--l);
        while (r < q[i].r)
            ADD(++r);
        while (r > q[i].r)
            SUB(r--);
        ans[q[i].id] = tot;
    }
    for (int i = 1; i <= m; i++) {
        cout << ans[i] << '\n';
    } 
    return 0;
} 

P1494 [国家集训队] 小 Z 的袜子

答案就是 \operatorname{sum}(C^{cnt_c}_{2}) / ((r - l + 1) \times (r - l) / 2)

化简得:(\operatorname{sum}(c_i^2) - (r - l + 1)) / ((r - l + 1) \times (r - l))

#include <bits/stdc++.h>
#define int long long

using namespace std;

const int N = 5e4 + 10;

int n, m, len, a[N], cnt[N], tot, k;
pair<int, int> ans[N];

struct Query {
    int l, r, id;
} q[N];

bool cmp(Query a, Query b) {
    if (a.l / len != b.l / len)
        return a.l < b.l;
    return a.r < b.r;
}

void SUB(int x) {
    tot -= cnt[a[x]] * cnt[a[x]];
    cnt[a[x]]--;
    tot += cnt[a[x]] * cnt[a[x]];
  return;
}

void ADD(int x) {
    tot -= cnt[a[x]] * cnt[a[x]];
    cnt[a[x]]++;
    tot += cnt[a[x]] * cnt[a[x]];
  return;
}

signed main() {
    cin >> n >> m;
    len = sqrt(n);
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
    } 
    for (int i = 1; i <= m; i++) {
        cin >> q[i].l >> q[i].r;
        q[i].id = i;
    }
    sort (q + 1, q + 1 + m, cmp);
    int l = 1, r = 0;
    for (int i = 1; i <= m; i++) {
        while (l < q[i].l)
            SUB(l++);
        while (l > q[i].l)
            ADD(--l);
        while (r < q[i].r)
            ADD(++r);
        while (r > q[i].r)
            SUB(r--);
        if (q[i].l == q[i].r) {
            ans[q[i].id] = {0, 1};
        } else {
            int g = __gcd(tot - (q[i].r - q[i].l + 1), (q[i].r - q[i].l + 1) * (q[i].r - q[i].l));
            ans[q[i].id] = {(tot - (q[i].r - q[i].l + 1)) / g, (q[i].r - q[i].l + 1) * (q[i].r - q[i].l) / g};
        }
    }
    for (int i = 1; i <= m; i++) {
        printf("%d/%d\n", ans[i].first, ans[i].second);
    } 
    return 0;
}