浅谈网格图分治

· · 算法·理论

原理

网格图分治,顾名思义,是在网格图结构上进行分治的一类算法,常用于处理多次询问两点最短路的问题。

来看一个板子:P3350 [ZJOI2016] 旅行者。

给出一个 n \times m 的无向带权网格图。q 次询问从 (x_1,y_1)(x_2,y_2) 的最短路。

基本思想

将所有询问离线后进行分治处理。对于当前考虑的子网格:

::::info[图示]{open}

考虑下面这个例子,红色为起点,绿色为终点。

沿蓝色虚线将网格“撕开”,并在中线上更新答案:

::::

这样一来:

复杂度分析

根据均值不等式,上述算法的最劣时间复杂度在网格的长和宽同阶时取得。因此不妨将网格视为边长为 n 的正方形进行分析,此时递推式为

T(n)=O(n^3 \log n)+4T(n/2)

由主定理得 T(n)=O(n^3 \log n)

更一般地,设网格图的总规模为 S,询问数量为 q,则该分治算法的整体时间复杂度为 O\left( (S+q) \sqrt{S} \log S\right)。而如果最短路部分可以加速到线性,复杂度可以优化到 O\left((S+q) \sqrt{S}\right)

代码

::::info[代码]

const int N = 1e5 + 5, inf = 1e9;

int n, m, q, ans[N];
using query = tuple<int, int, int, int, int>;  // {x1, y1, x2, y2, id}
vector<query> qry;
// L, R, U, D
const int dirx[] = {0, 0, 0, -1, 1};
const int diry[] = {0, -1, 1, 0, 0};
vector<vector<int>> dis, e[5];

struct node {
    int x, y, d;
    node(int a = 0, int b = 0, int c = 0) : x(a), y(b), d(c) {}
    bool operator< (const node& a) const {return d > a.d;}
};

void dijkstra(int lx, int rx, int ly, int ry, int sx, int sy, const vector<query>& q) {
    if (dis.empty()) dis.resize(n + 5, vector<int>(m + 5, 0));
    for (int x = lx; x <= rx; x++) {
        for (int y = ly; y <= ry; y++) dis[x][y] = inf;
    }
    priority_queue<node> h;
    dis[sx][sy] = 0, h.emplace(sx, sy, 0);
    while (!h.empty()) {
        auto [x, y, dl] = h.top(); h.pop();
        if (dl != dis[x][y]) continue;
        for (int d = 1; d <= 4; d++) {
            int nx = x + dirx[d], ny = y + diry[d];
            if (nx < lx || nx > rx || ny < ly || ny > ry) continue;
            if (dis[nx][ny] > dis[x][y] + e[d][x][y]) dis[nx][ny] = dis[x][y] + e[d][x][y], h.emplace(nx, ny, dis[nx][ny]);
        }
    }
    for (const auto& [x1, y1, x2, y2, id] : q) ans[id] = min(ans[id], dis[x1][y1] + dis[x2][y2]);
}

void solve(int lx, int rx, int ly, int ry, vector<query> q) {
    if (lx > rx || ly > ry || q.empty()) return;
    if (rx - lx > ry - ly) {
        int mid = (lx + rx) >> 1;
        for (int y = ly; y <= ry; y++) dijkstra(lx, rx, ly, ry, mid, y, q);
        vector<query> ql, qr;
        for (const auto& [x1, y1, x2, y2, id] : q) {
            if (x1 < mid && x2 < mid) ql.emplace_back(x1, y1, x2, y2, id);
            if (mid < x1 && mid < x2) qr.emplace_back(x1, y1, x2, y2, id);
        }
        solve(lx, mid - 1, ly, ry, ql), solve(mid + 1, rx, ly, ry, qr);
    } else {
        int mid = (ly + ry) >> 1;
        for (int x = lx; x <= rx; x++) dijkstra(lx, rx, ly, ry, x, mid, q);
        vector<query> ql, qr;
        for (const auto& [x1, y1, x2, y2, id] : q) {
            if (y1 < mid && y2 < mid) ql.emplace_back(x1, y1, x2, y2, id);
            if (mid < y1 && mid < y2) qr.emplace_back(x1, y1, x2, y2, id);
        }
        solve(lx, rx, ly, mid - 1, ql), solve(lx, rx, mid + 1, ry, qr);
    }
}

void _main() {
    cin >> n >> m;
    for (int d = 1; d <= 4; d++) e[d].resize(n + 5, vector<int>(m + 5, 0));
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j < m; j++) cin >> e[2][i][j], e[1][i][j + 1] = e[2][i][j];
    }
    for (int i = 1; i < n; i++) {
        for (int j = 1; j <= m; j++) cin >> e[4][i][j], e[3][i + 1][j] = e[4][i][j];
    }
    cin >> q;
    for (int i = 1, x1, y1, x2, y2; i <= q; i++) cin >> x1 >> y1 >> x2 >> y2, qry.emplace_back(x1, y1, x2, y2, i);
    fill(ans + 1, ans + q + 1, inf);
    solve(1, n, 1, m, qry);
    for (int i = 1; i <= q; i++) cout << ans[i] << '\n';
} 

::::

应用

P9040 [PA 2021] Desant 2

给定一个长为 n 的序列 a。有 q 次询问,对区间 [l,r] 求解如下问题:

对区间 [l,r],将其划分为若干不相交连续段,每段长度恰好为 k,允许最多一个元素不属于任何段,最大化各段元素和之和。

:::::success[题解] 记 p_ia_i 的前缀和,设 f_i 为考虑到第 i 个元素的最大得分,容易写出 DP 转移:

f_i =\max(f_{i-1},f_{i-k}+s_i-s_{i-k})

对于每次询问,令 f_{l-1} \gets 0 然后跑 DP,视 n,q 同阶,复杂度 O(n^2)

如果直接用矩阵优化可以做到 O(nk^3),猫树分治可以做到 O(nk \log n),和我一样没有前途。

将上述 DP 转移视为一个自动机,可以建成如下 DAG:

这样,问题等价于:询问从 l-1r 的最长路。

我们观察一下这个图的形态:

(借用了这篇题解里的图)

这个 DAG 的结构相当于一个 k \times \left\lfloor n/k\right \rfloor 的网格图加上了若干斜边,自然想到网格图分治。

考虑一下怎么处理斜边。注意到这些斜边只会在第一次按行分治时起作用,因此可以在这一层单独暴力处理斜边,不影响整体复杂度。

我们可以用简单的 DP 跑出中线向两侧的最短路,复杂度就是 O((n+q)\sqrt{n}) 的。

::::info[代码]

const int N = 6e5 + 5;
const int64 inf = 1e18;

int n, k, q;
int64 a[N], ans[N], dis[N];
using query = tuple<int, int, int, int, int>;  // {x1, y1, x2, y2, id}
vector<query> qry;

void update(int lx, int rx, int ly, int ry, int sx, int sy, const vector<query>& q) {
    int s = sx * k + sy;
    for (int x = lx; x <= rx; x++) {
        for (int y = ly; y <= ry; y++) dis[x * k + y] = -inf;
    }
    dis[s] = 0;
    auto chk = [&](int i) -> bool {return lx <= i / k && i / k <= rx && ly <= i % k && i % k <= ry;};
    for (int x = sx; x >= lx; x--) {
        for (int y = ry; y >= ly; y--) {
            int i = x * k + y;
            if (i > s) continue;
            if (chk(i + 1)) dis[i] = max(dis[i], dis[i + 1]);
            if (chk(i + k)) dis[i] = max(dis[i], dis[i + k] + a[i + k] - a[i]);
        }
    }
    for (int x = sx; x <= rx; x++) {
        for (int y = ly; y <= ry; y++) {
            int i = x * k + y;
            if (i < s) continue;
            if (chk(i + 1)) dis[i + 1] = max(dis[i + 1], dis[i]);
            if (chk(i + k)) dis[i + k] = max(dis[i + k], dis[i] + a[i + k] - a[i]);
        }
    }
    for (const auto& [x1, y1, x2, y2, id] : q) {
        if (x1 * k + y1 > s || x2 * k + y2 < s) continue;
        ans[id] = max(ans[id], dis[x1 * k + y1] + dis[x2 * k + y2]);
    }
}

void solve(int lx, int rx, int ly, int ry, vector<query> q) {
    if (lx > rx || ly > ry || q.empty()) return;
    if (rx - lx > ry - ly) {
        int mid = (lx + rx) >> 1;
        for (int y = ly; y <= ry; y++) update(lx, rx, ly, ry, mid, y, q);
        vector<query> ql, qr;
        for (const auto& [x1, y1, x2, y2, id] : q) {
            if (x1 < mid && x2 < mid) ql.emplace_back(x1, y1, x2, y2, id);
            if (mid < x1 && mid < x2) qr.emplace_back(x1, y1, x2, y2, id);
        }
        solve(lx, mid - 1, ly, ry, ql), solve(mid + 1, rx, ly, ry, qr);
    } else {
        int mid = (ly + ry) >> 1;
        if (ly == 0 && ry == k - 1) {  // 斜边
            for (int x = lx; x <= rx; x++) update(lx, rx, ly, ry, x, 0, q);
        }
        for (int x = lx; x <= rx; x++) update(lx, rx, ly, ry, x, mid, q);
        vector<query> ql, qr;
        for (const auto& [x1, y1, x2, y2, id] : q) {
            if (y1 < mid && y2 < mid) ql.emplace_back(x1, y1, x2, y2, id);
            if (mid < y1 && mid < y2) qr.emplace_back(x1, y1, x2, y2, id);
        }
        solve(lx, rx, ly, mid - 1, ql), solve(lx, rx, mid + 1, ry, qr);
    }
}

void _main() {
    cin >> n >> k >> q;
    for (int i = 1; i <= n; i++) cin >> a[i], a[i] += a[i - 1];
    for (int i = 1, l, r; i <= q; i++) {
        cin >> l >> r, l--;
        qry.emplace_back(l / k, l % k, r / k, r % k, i);
    }
    solve(0, n / k, 0, k - 1, qry);
    for (int i = 1; i <= q; i++) cout << ans[i] << '\n';
}

:::: :::::

P8182 「EZEC-11」雪的魔法

给定长度为 n 的序列 a,有 q 次询问。

一次操作可选择一个长度不超过 k 的子区间,将其中所有数减一。

求将区间 [l,r] 内所有数变为 0 的最小操作次数。

:::::success[题解] 这种问题笔者称之为“铲雪题”。我正在编写一篇叫《十二重铲雪法》的文章,会把这个题收录进去。下面直接给出本题的等价问题:

将区间中每个位置赋一个权值 w_i \in \{-1,0,1\},满足任意长度不超过 k 的子区间的权值和不超过 1。最大化 \sum a_i w_i

::::info[为什么] 前置知识:线性规划、对偶原理。

\mathcal I 表示所有路径,对每条路径 S 给出一个非负整数 x_S 表示 S 被经过的次数,原问题即:

\begin{aligned} \min & \sum _{S \in \mathcal{I}} x_S\\ \text{subject to } & \sum_{u \in S} x_S=a_u, \\ & x_S \ge 0. \end{aligned}

拆成不等式以得到标准形式:

\begin{aligned} \min & \sum _{S \in \mathcal{I}} x_S\\ \text{subject to } & \sum_{u \in S} x_S \ge a_u, \\ & \sum_{u \in S} -x_S \ge -a_u, \\ & x_S \ge 0. \end{aligned}

得到对偶问题

\begin{aligned} \max & \sum _{u} a_u(p_u-q_u)\\ \text{subject to } & \sum_{u \in S} (p_u-q_u) \le 1, \\ & p_u,q_u \ge 0. \end{aligned}

换元,令 w_u=p_u-q_u,问题变为

\begin{aligned} \max & \sum _{u} a_uw_u\\ \text{subject to } & \sum_{u \in S} w_u \le 1. \end{aligned}

用自然语言表述这个问题,相当于找到一组整数 w_u,对所有满足条件的路径,w_u 之和不大于 1,最大化 a_uw_u 的和。

显然 w_u \le 1。同时对 w_u \le -2,总能调整法找到所有 w_u \ge -1 的解。因此 w_u \in \{-1,0,1\} ::::

考虑 k \ge n 的情况,此时原问题等价于积木大赛,答案即所有正差分之和,也即 \sum \max(0,a_i-a_{i-1})

我们称:按照积木大赛得到的答案是平凡的。对于对偶问题,这样的答案会偏小,因为对偶问题的限制是相对松的。

f_i 为考虑到第 i 个位置的答案,初始 f_l=a_l。对于 i=l+1,l+2,\cdots,r,有如下转移方程:

f_i=\max(f_{i-1}+\max(0,a_i-a_{i-1}),f_{i-k}+a_i)

其中,f_{i-1}+\max(0,a_i-a_{i-1})平凡的。而 f_{i-k}+a_i 的意思是将 [i-k+1,i-1] 中的权值全部置 0,并且令 w_i=1

这个 DP 的正确性不是很显然。但是仔细思考一下,对于一个不平凡的情况,如果 [i-k+1,i-1] 中存在非零权值,那么这总能规约到一个平凡情况。因此,我们得到了一个 O(nq) 做法。

观察一下这个 DP 式子,继续使用网格图分治优化 DP 即可。唯一不同的点在于起点 l 带了点权,加上去即可。复杂度 O((n+q) \sqrt{n})

::::info[代码]

const int N = 2e5 + 5;
const int64 inf = 1e18;

int n, k, q;
int64 a[N], ans[N], dis[N];
using query = tuple<int, int, int, int, int>;  // {x1, y1, x2, y2, id}
vector<query> qry;

void update(int lx, int rx, int ly, int ry, int sx, int sy, const vector<query>& q) {
    int s = sx * k + sy;
    for (int x = lx; x <= rx; x++) {
        for (int y = ly; y <= ry; y++) dis[x * k + y] = -inf;
    }
    dis[s] = 0;
    auto chk = [&](int i) -> bool {return lx <= i / k && i / k <= rx && ly <= i % k && i % k <= ry;};
    for (int x = sx; x >= lx; x--) {
        for (int y = ry; y >= ly; y--) {
            int i = x * k + y;
            if (i > s) continue;
            if (chk(i + 1)) dis[i] = max(dis[i], dis[i + 1] + max(0LL, a[i + 1] - a[i]));
            if (chk(i + k)) dis[i] = max(dis[i], dis[i + k] + a[i + k]);
        }
    }
    for (int x = sx; x <= rx; x++) {
        for (int y = ly; y <= ry; y++) {
            int i = x * k + y;
            if (i < s) continue;
            if (chk(i + 1)) dis[i + 1] = max(dis[i + 1], dis[i] + max(0LL, a[i + 1] - a[i]));
            if (chk(i + k)) dis[i + k] = max(dis[i + k], dis[i] + a[i + k]);
        }
    }
    for (const auto& [x1, y1, x2, y2, id] : q) {
        if (x1 * k + y1 > s || x2 * k + y2 < s) continue;
        ans[id] = max(ans[id], a[x1 * k + y1] + dis[x1 * k + y1] + dis[x2 * k + y2]);
    }
}

void solve(int lx, int rx, int ly, int ry, vector<query> q) {
    if (lx > rx || ly > ry || q.empty()) return;
    if (rx - lx > ry - ly) {
        int mid = (lx + rx) >> 1;
        for (int y = ly; y <= ry; y++) update(lx, rx, ly, ry, mid, y, q);
        vector<query> ql, qr;
        for (const auto& [x1, y1, x2, y2, id] : q) {
            if (x1 < mid && x2 < mid) ql.emplace_back(x1, y1, x2, y2, id);
            if (mid < x1 && mid < x2) qr.emplace_back(x1, y1, x2, y2, id);
        }
        solve(lx, mid - 1, ly, ry, ql), solve(mid + 1, rx, ly, ry, qr);
    } else {
        int mid = (ly + ry) >> 1;
        if (ly == 0 && ry == k - 1) {  // 斜边
            for (int x = lx; x <= rx; x++) update(lx, rx, ly, ry, x, 0, q);
        }
        for (int x = lx; x <= rx; x++) update(lx, rx, ly, ry, x, mid, q);
        vector<query> ql, qr;
        for (const auto& [x1, y1, x2, y2, id] : q) {
            if (y1 < mid && y2 < mid) ql.emplace_back(x1, y1, x2, y2, id);
            if (mid < y1 && mid < y2) qr.emplace_back(x1, y1, x2, y2, id);
        }
        solve(lx, rx, ly, mid - 1, ql), solve(lx, rx, mid + 1, ry, qr);
    }
}

void _main() {
    cin >> n >> k >> q;
    for (int i = 1; i <= n; i++) cin >> a[i];
    for (int i = 1, l, r; i <= q; i++) {
        cin >> l >> r;
        qry.emplace_back(l / k, l % k, r / k, r % k, i);
    }
    solve(0, n / k, 0, k - 1, qry);
    for (int i = 1; i <= q; i++) cout << ans[i] << '\n';
} 

::::

:::::