折半/二进制警报器 学习笔记

· · 算法·理论

好像两百年没写学习笔记了?感觉比较新的东西记一个。

并不一定更好的阅读体验。

引入

折半/二进制警报器主要可以解决这样的问题:

有两种操作:

  1. 给某个集合加上一个值。如果有满足了某个询问就输出满足了哪些询问

  2. 询问一个集合里的数之和在第多少次操作后第一次大于一个值 V_i

集合大小比较小,大概是常数量级,假设是 k\le 20n,q\le 2\times 10^5

我们当然是可以直接整体二分的。但是如果我们的问题要求在线,那就只能用警报器了,感觉以后有一部分的整体二分都可以改成这个东西啊!

做法

折半警报器

首先我们会一个朴素的暴力:每次询问把集合内的数都挂一个询问的记号,每次一个数增加就暴力 check 询问是否合法,显然我构造一个每次加一就变成 q^2 了,很完蛋。

我们考虑一个像乱搞一样的东西,首先我们发现回答一个询问的必要条件是,存在一个数 \ge \frac {V_i} k,那么我们就给集合里的数都塞一个 \frac {V_i} k 的警报器,当某个数增加之后到了这个阈值,我们就令其“报警”,暴力计算这个警报器对应集合是否合法,合法即成立,否则我们记目前集合的和为 s,那么我们再重新把集合里都塞一个 \frac{V_i-s}{k} 的警报器,重复做直到询问被回答。

我们分析一下复杂度,注意到复杂度来源于我们每次“报警”需要重构,所以我们这里是 O(k) 的复杂度,那么会被重构多少次呢?注意到每次“报警”之后 V\to \frac {k-1} k V,所以最多“报警”O(\log_{\frac{k}{k-1}}V) = O(\frac{\ln V}{\ln k - \ln{k-1}}) = O(k\log V)。同时我们需要用一个堆对于每个元素维护最小的警报取出来,所以总复杂度是 O(k^2\log V\log q)

模板。

二进制警报器

由 zak 发明,挂一个原文的 blog。

我们发现折半警报器有两个很劣的事情,一个是我需要用堆维护报警器,一个是每次报警我需要 O(k) 的复杂度重构一下。我们大概想的是我们有一个较能整体一起进行的方式。比如对于第一个问题,我们就需要我们的警报是一些固定的值,这样就很方便我们用桶排而不用堆进行维护。对于第二个问题,我们需要一种方式能限制一下我们有效报警,也就是需要重构的次数。

我们考虑用二进制报警器,具体的,对于每一个集合我们维护一个阈值 h,当集合中有一个数 a_i\to a_i + w 的时候,如果 (a_i,a_i+w] 跨过了一个 2^h 的倍数,那么我们就报警。如果当前的 h 能在集合内的数都不报警的情况下到达我们需要的值时,我们令 h 减一。

考虑分析一下这样做的复杂度,对于 h=h_0,因为 V_i - \sum a 不能在全都不报警的情况下达到,所以是 O(k2^{h_0}) 的,而对于每个数,报警两次一定会增加 2^{h_0},所以在 h=h_0 时会报警至多 k 次。报警总次数就是 k\log V 的。

这样还有一个好处是我们只有至多 O(\log V) 个值需要记录,可以不用堆直接桶记录即可,优化掉了 \log q

但是现在我们还是没有优化掉重构的次数。注意到我们在 h 降到 -1 的时候就等于答案了,我们可以动态记录一个值 lim,代表在 h=h_0 时,我至多在不触发警报的时候满足多少,在每次 h\to h-1 的时候重构这个 lim 即可,这样我们在一次警报的时候就没有必要暴力重构,而是检验这个 lim 是否合法即可。

这样我们就可以把这类问题优化到 O(qk\log V),同时常数极小,唯一的缺点在于,这个方法的空间复杂度是 O(nk\log V) 的,对比整体二分还是劣了不少。

可以用二进制警报器的方式再写一下模板,放一个我的代码(因为我比较偷懒,所以模板包括下一个例题都写的是 k^2\log V,最后一个是严格的 k\log V)。

#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 1e5 + 5;
int n, m, lab[maxn], lpf[maxn];
int prime[maxn], tot, vis[maxn], lim[maxn], x[maxn];
void prepare() {
    vis[0] = vis[1] = 1;
    for (int i = 2; i <= n; i++) {
        if(!vis[i])
            prime[++tot] = i, lpf[i] = i;
        for (int j = 1; j <= tot && prime[j] * i <= n; j++) {
            vis[i * prime[j]] = 1;
            lpf[i * prime[j]] = prime[j];
            if(i % prime[j] == 0)
                break;
        }
    }
}
int idx;
vector<int> pos[maxn];
vector<int> get_pos(int x) {
    vector<int> t;
    int lst = 0;
    while(x != 1) {
        if(lst != lpf[x])
            t.push_back(lpf[x]), lst = lpf[x], x /= lst;
        else
            x /= lst;
    } 
    return t;
}
vector<int> t;
bool cross(int l, int r, int v) {
    return r / v - l / v;
}
struct node {
    int p, val;
    friend bool operator<(node x, node y) {
        return x.val > y.val;
    }
} ;
int use[maxn];
vector<int> q[64][maxn];
int a[maxn];
queue<int> tp;
int get_nxt(int x, int y) {
    return (x / y + 1) * y;
}
vector<int> ans;
void renew(int p, int v) {
    vector<int> res;
    if(v == 0)
        return ;
    for (int i = 0; i < 64; i++) {
        if(cross(a[p], a[p] + v, (1ll << i))) {
            for (int j = 0; j < q[i][p].size(); j++) {
                int post = q[i][p][j];
            //  cout << lab[post] << " " << post << " " << i << endl;
                if(lab[post] != i || use[post])
                    continue;
                int rest = v;
                for (int i = 0; i < pos[post].size(); i++) {
                    int nwp = pos[post][i];
                    rest += a[nwp];
                }
                if(rest >= lim[post]) {
                    ans.push_back(post);
                    use[post] = 1;
                }
                else 
                    tp.push(post);
            }
            q[i][p].clear();
        }
    }
    a[p] += v;
    while(!tp.empty()) {
        int t = tp.front(); tp.pop();
        int bg = lab[t];
        while(1) {
            int r = 0;
            for (int i = 0; i < pos[t].size(); i++)
                r += get_nxt(a[pos[t][i]], (1ll << lab[t])) - 1;
            if(r >= lim[t]) 
                lab[t]--;
            else
                break;
        }
        if(bg != lab[t]) {
            for (int i = 0; i < pos[t].size(); i++)
                q[lab[t]][pos[t][i]].push_back(t);
        }
        else 
            q[lab[t]][p].push_back(t);
    }
}
signed main() {
    cin >> n >> m;
    prepare();
    int lst = 0;
    vector<int> rm;
//  cout << get_nxt(0, (1ll << 61)) << endl;
    while(m--) {
        int op, x, y; cin >> op >> x >> y;
        y ^= lst;
        if(op == 0) {
            t = get_pos(x);
            ans.clear();
            for (int i = 0; i < t.size(); i++) 
                renew(t[i], y);
            for (int i = 0; i < rm.size(); i++)
                ans.push_back(rm[i]);
            rm.clear();
        //  cout << "adfafasdgag";
            sort(ans.begin(), ans.end());
            cout << ans.size() << " ";
            for (int i = 0; i < ans.size(); i++)
                cout << ans[i] << " ";
            cout << endl;
        //  cout << lab[1] << endl;
            lst = ans.size();
            ans.clear();
        }
        else {
            t = get_pos(x);
            idx++;
            pos[idx] = t;
            lab[idx] = 51;
            if(!y) {
                rm.push_back(idx);
                continue;
            }
            lim[idx] = y;
            for (int i = 0; i < pos[idx].size(); i++)
                lim[idx] += a[pos[idx][i]];
            int r = 0;
            while(1) {
                int r = 0;
                for (int i = 0; i < pos[idx].size(); i++)
                    r += get_nxt(a[pos[idx][i]], (1ll << lab[idx])) - 1;
                if(r >= lim[idx]) 
                    lab[idx]--;
                else
                    break;
            }
            for (int i = 0; i < pos[idx].size(); i++)
                q[lab[idx]][pos[idx][i]].push_back(idx);
        }
    }
    return 0;
}

例题

qoj7415 Fast Spanning Tree

deepseek 因为我复制的时候小于等于号被吃掉了导致看不懂题,被硬控了 20min。

题意:给出若干条带权边,每个点有点权,我们一开始有一个完全空的图,重复做以下操作。

做法: 其实我们警报器做的问题都自带一点在线属性,比如这个题就是隐藏了一个在线的事情。 每次就等于合并两个连通块然后把他们的警报器合并以下即可,其他倒是没有什么大的区别。当然你这个警报器直接合并,用启发式可以合并但是太蠢了,而且空间这题是 256 MB,这样常数嘎嘎大,我们可以改成链表维护,这样就可以 $O(1)$ 简单维护。 折半这个题也是可以做,稍微常数大一点多一个 $\log$。 有点轻微卡空间。 代码: ``` #include <bits/stdc++.h> using namespace std; const int maxn = 3e5 + 5; int n, w[maxn], m, x[maxn], y[maxn], v[maxn], er[maxn]; vector<int> lim[maxn]; int pre[maxn], sz[maxn]; int fnd(int x) { return (pre[x] == x ? x : pre[x] = fnd(pre[x])); } void prepare() { for (int i = 1; i <= n; i++) pre[i] = i, sz[i] = 1; } int h[maxn]; int get_nxt(int x, int v) { return (x / v + 1) * v; } bool cross(int l, int r, int v) { return r / v - l / v; } set<int> s; struct node { int p, nxt; } tr[maxn * 50]; int tot; struct My_list { int f, t; void push_back(int x) { if(tot % 100000 == 0) cerr << tot << endl; tr[++tot] = node{x, 0}; if(!f) f = tot; else tr[t].nxt = tot; t = tot; } friend My_list operator+(My_list x, My_list y) { if(!x.f) return y; if(!y.f) return x; tr[x.t].nxt = y.f; x.t = y.t; return x; } void clear() { f = t = 0; } } vec[maxn][26]; int use[maxn]; queue<int> q; void renew(int p, int vt) { if(vt == 0) return ; for (int i = 0; i <= 20; i++) { if(cross(w[p], w[p] + vt, (1ll << i))) { // cout << i << endl; for (int t = vec[p][i].f; t; t = tr[t].nxt) { int pos = tr[t].p; // cout << pos << "adslksajg" << " " << h[pos] << " " << i << " " << use[pos] << endl; if(use[pos] || h[pos] != i) continue; if(w[fnd(x[pos])] + w[fnd(y[pos])] + vt >= v[pos]) s.insert(pos), use[pos] = 1; else q.push(pos); } vec[p][i].clear(); } } w[p] += vt; w[p] = min(w[p], 1000000); while(!q.empty()) { int i = q.front(); q.pop(); if(fnd(x[i]) == fnd(y[i])) continue; int bg = h[i]; // cout << i << "aslaskjg" << " " << << endl; while(1) { int r = get_nxt(w[fnd(x[i])], 1ll << h[i]) + get_nxt(w[fnd(y[i])], 1ll << h[i]) - 2; if(r >= v[i]) h[i]--; else break; if(h[i] < 0) cerr << "Wrong" << " " << w[fnd(x[i])] << " " << w[fnd(y[i])] << " " << v[i] << " " << r << endl; } if(bg != h[i] || fnd(x[i]) == p) vec[fnd(x[i])][h[i]].push_back(i); if(bg != h[i] || fnd(y[i]) == p) vec[fnd(y[i])][h[i]].push_back(i); } } void unn(int x, int y) { x = fnd(x); y = fnd(y); if(x == y) { return ; } if(sz[x] > sz[y]) swap(x, y); for (int i = 0; i < lim[x].size(); i++) lim[y].push_back(lim[x][i]); int t = w[y]; renew(y, w[x]); renew(x, t); for (int i = 0; i <= 25; i++) vec[y][i] = vec[x][i] + vec[y][i]; pre[x] = y, sz[y] += sz[x]; } vector<int> ans; signed main() { cin >> n >> m; for (int i = 1; i <= n; i++) cin >> w[i]; for (int i = 1; i <= m; i++) { cin >> x[i] >> y[i] >> v[i]; if(w[x[i]] + w[y[i]] >= v[i]) { s.insert(i); use[i] = 1; continue; } lim[x[i]].push_back(i), lim[y[i]].push_back(i); } prepare(); for (int i = 1; i <= m; i++) { if(use[i]) continue; h[i] = 25; //cout << "adf" << endl; while(1) { int r = get_nxt(w[fnd(x[i])], 1ll << h[i]) + get_nxt(w[fnd(y[i])], 1ll << h[i]) - 2; // cout << r << " " << h[i] << " " << v[i] << " " << get_nxt(w[fnd(x[i])], 1ll << h[i]) << endl; if(r >= v[i]) h[i]--; else break; } vec[fnd(x[i])][h[i]].push_back(i); vec[fnd(y[i])][h[i]].push_back(i); } while(s.size()) { // for (set<int>::iterator it = s.begin(); it != s.end(); it++) // cout << *it << " "; // cout << endl; int pos = *s.begin(); s.erase(pos), use[pos] = 1; if(fnd(x[pos]) == fnd(y[pos])) continue; ans.push_back(pos); if(ans.size() % 1000 == 0) cerr << 123 << endl; unn(x[pos], y[pos]); //cout << pos << " " << h[1] << " " << w[fnd(y[2])] + w[fnd(x[2])] << endl; } cout << ans.size() << endl; for (int i = 0; i < ans.size(); i++) cout << ans[i] << " "; cout << endl; return 0; } /* 5 5 1 1 1 1 1 1 2 1 2 3 5 3 4 2 4 5 2 2 3 2 */ ``` ## [qoj8035 Call Me Call Me](https://qoj.ac/problem/8035) 题意:有 $n$ 个人,每个人只有当 $[l,r]$ 中的人中有 $k$ 个接到电话后才会愿意接电话,问你最多能给多少个人打电话,$n\le 4\times 10^5$。 做法: 我们看到之前题目的集合大小都很小,这个题的集合大小很大是 $O(n)$,没办法直接做,但是我们看到区间,很自然地可以把 $[l,r]$ 拍到线段树区间上,直接二进制警报器维护可以做到 $O(n\log ^2n)$,空间相同,稍微有点小卡。有一个不是很 pratical 的做法是,我们想减少一点这个区间个数,可以用多叉树维护,这样就可以拆成 $O(\frac {\log n}{\log\log n})$ 个区间,可以优化一点理论复杂度但是常数会比较大。 但是!让我们回头来研究一下折半的事情。我们直接做折半确实很完蛋,是四只 $\log$ 的。我们考虑猫树一下,把每个区间只分成 $mid$ 的前半段和后半段两个区间。对于一个修改,我们影响的是猫树区间前后缀的事情,我们可以用线段树稍微维护一下每个猫树区间前后缀中警报器最小值,同样也可以做到 $2\log$,会带一些常数。 顺提一嘴这个题的其他做法但是不是本文的重点就随便说说:考虑,如果我一次加入 $B$ 个操作,那么就只有那些要求人数 $\le B$ 的才有可能被确定成合法的,所以我们每次加入 $B$ 个操作,前缀和暴力计算一下目前可能合法的这些区间,每次对于一个点标为合法,就暴力去对于所有包含他的区间去判断够不够,这里可以用挂在线段树上的方式快速找出包含一个点的所有区间,因为每个区间最多被暴力确定 $B$ 次,所以平衡一下就是 $n\sqrt n$ 的,因为空间是线性的加上常数很小,跑得比警报器快一万倍。 从这个题大概就可以看得出折半和二进制的分别的优势,二进制的比较快,但是折半更具有拓展性,可以用很多数据结构去维护一下。 给一个二进制警报器的代码,这份应该是比较标准的: ``` #include <bits/stdc++.h> using namespace std; const int maxn = 4e5 + 5; int n, l[maxn], r[maxn], k[maxn], lim[maxn], h[maxn], s[maxn]; vector<int> pos[maxn]; void get_pos(int l, int r, int x, int y, int t, int p) { if(x <= l && r <= y) { pos[p].push_back(t); return ; } int mid = l + r >> 1; if(x <= mid) get_pos(l, mid, x, y, t << 1, p); if(mid < y) get_pos(mid + 1, r, x, y, t << 1 | 1, p); } int get_nxt(int x, int v) { return (x / v + 1) * v; } bool cross(int l, int r, int v) { return r / v - l / v; } vector<int> vec[maxn << 2][21]; int val[maxn << 2], use[maxn]; queue<int> q; void rebuild(int x) { // if(x == 3) // cerr << h[x] << "debug" << endl; while(h[x] >= 0) { lim[x] = 0; for (int j = 0; j < pos[x].size(); j++) { lim[x] += get_nxt(val[pos[x][j]], 1 << h[x]) - 1; if(lim[x] >= k[x]) break; } if(lim[x] >= k[x]) h[x]--; else break; } if(h[x] < 0) { q.push(x); use[x] = 1; // cerr << x << " asdf" << endl; return ; } // if(x == 3) // cerr << h[x] << "debug" << endl; for (int j = 0; j < pos[x].size(); j++) vec[pos[x][j]][h[x]].push_back(x); } int ans = 0; queue<int> ers; void renew(int x, int v) { //cout << x << " " << v << endl; for (int i = 0; i <= 20; i++) { if(cross(val[x], val[x] + v, (1 << i))) { for (int j = 0; j < vec[x][i].size(); j++) { int p = vec[x][i][j]; if(use[p] || h[p] != i) continue; ers.push(p); lim[p] -= get_nxt(val[x], 1 << i) - 1; lim[p] += get_nxt(val[x] + v, 1 << i) - 1; // if(p == 3) // cerr << x << " " << v << " " << i << " " << lim[p] << " " << l[p] << " " << r[p] << endl; } vec[x][i].clear(); } } val[x] += v; while(!ers.empty()) { int p = ers.front(); ers.pop(); if(lim[p] >= k[p]) rebuild(p); else vec[x][h[p]].push_back(p); } } void modify(int l, int r, int t, int pos) { renew(t, 1); if(l == r) return ; int mid = l + r >> 1; if(pos <= mid) modify(l, mid, t << 1, pos); else modify(mid + 1, r, t << 1 | 1, pos); } signed main() { cin >> n; for (int i = 1; i <= n; i++) cin >> l[i] >> r[i] >> k[i], get_pos(1, n, l[i], r[i], 1, i); for (int i = 1; i <= n; i++) if(!k[i]) q.push(i), use[i] = 1; for (int i = 1; i <= n; i++) { if(!k[i]) continue; h[i] = 20; while(1) { lim[i] = 0; for (int j = 0; j < pos[i].size(); j++) { lim[i] += get_nxt(val[pos[i][j]], 1 << h[i]) - 1; if(lim[i] >= k[i]) break; } if(lim[i] >= k[i]) h[i]--; else break; } for (int j = 0; j < pos[i].size(); j++) vec[pos[i][j]][h[i]].push_back(i); } while(!q.empty()) { ans++; int p = q.front(); q.pop(); // cerr << p << endl; modify(1, n, 1, p); } cout << ans << endl; return 0; } /* 5 1 3 2 2 4 2 1 5 1 4 5 1 1 5 0 */ ```