决策单调性优化DP
SegmentDog · · 算法·理论
基础知识
四边形不等式
如果一个二元函数
那么称其满足四边形不等式。因为在求
的简便地记为
下面讨论的四边形不等式均为
四边形不等式还有一个等价形式:对于任意
一般证明四边形不等式就利用这种形式或类似这种的形式,这样只需要考虑从
动态规划中我们经常列出
这个式子其实不用怎么特别记为『交叉小于包含』或者相反,你只需要理解它的本质是
决策单调性
对于一般情形的最大化线性 DP,可以表示为
其中
我们称所有的
此时如果当
注意决策单调性只保证了最优决策点是单调的,其它决策点间的关系没有任何保证。
完全单调性
对于任意一组
具有完全单调性的转移显然也具有决策单调性。
蒙日性
蒙日性其实就是四边形不等式性,只不过换了个名字,你想叫它四边形不等式性也没有问题。如果函数
:::info[定理]{open}
具有蒙日性的转移同时具有完全单调性
:::
:::info[证明]{open}
尝试从这个形式入手:
这个形式保证了一件事:如果在
现在考虑完全单调性的定义。任取了一个
:::
算法
当动态规划的转移方程满足上面的部分性质时,就可以利用它来优化转移。
注:下面的算法代码求的都是最大值,最小值只需要把部分符号反向即可。
分治
适用于
要求解所有
于是分治过程中维护最优决策点的上下界,取转移点区间的中点暴力求出其最优决策点,然后向左右两侧分治即可。这样每一层至多调用
:::success[代码]
int calc(int, int);
// l,r 表示转移点区间,L,R 表示决策点区间
void solve(int l, int r, int L, int R) {
if (l > r) return;
int mid = (l + r) >> 1, pos = -1, res = 0;
// 暴力查找 mid 位置的最优决策点
for (int i = L; i <= R && i < mid; i++)
if (!~pos || res < calc(i + 1, mid))
pos = i, res = calc(i + 1, mid);
f[mid] = res;
solve(l, mid - 1, L, pos), solve(mid + 1, r, pos, R);
}
:::
:::info[
于是只需要把 cost 函数改为下面这样:
int cl, cr, cans;
void mov(int, bool);
int calc(int l, int r) {
while (cl > l) mov(--cl, 1);
while (cr < r) mov(++cr, 1);
while (cl < l) mov(cl++, 0);
while (cr > r) mov(cr--, 0);
return cans;
}
:::
二分队列
适用于
我们动态维护一个队列,每个元素形如
查询时把过时的元素(
:::success[代码]
struct QNode { int l, r, v; };
struct Queue {
QNode q[XN]; int hh, tt;
void init() { hh = tt = 0; }
bool em() { return hh == tt; }
QNode& h() { return q[hh]; }
QNode& t() { return q[tt - 1]; }
void pph() { hh++; }
void ppt() { tt--; }
void pst(const QNode& val) { q[tt++] = val; }
} q;
int calc(int, int);
for (int i = 1; i <= n; i++) {
while (q.h().r < i) q.pph();
f[i] = cost(q.h().v, i);
if ((q.h().l = i + 1) > q.h().r) q.pph();
while (!q.em() && cost(q.t().v, q.t().l) <= cost(i, q.t().l)) q.ppt();
if (q.em()) q.pst(QNode{i + 1, n, i});
else if (cost(q.t().v, q.t().r) <= cost(i, q.t().r)) {
int l = q.t().l, r = q.t().r;
while (l < r) {
int mid = (l + r) >> 1;
if (cost(q.t().v, mid) <= cost(i, mid)) r = mid;
else l = mid + 1;
}
if ((q.t().r = l - 1) < q.t().l) q.ppt();
q.pst(QNode{l, n, i});
} else if (q.t().r < n)
q.pst(QNode{q.t().r + 1, n, i});
}
:::
简化 LARSCH 算法
适用于
这是一种分治算法。设当前对转移点区间
在进入函数时,需要求出
我们先找出转移点终点
算法的正确性在上面的过程中已经说明。首先递归层数是
注意:因为函数需要的条件,调用分治函数前,要用最左端先更新一次最右端。
当
:::success[代码]
int f[XN], fr[XN];
int cost(int, int);
void upd(int j, int i) {
if (!~fr[i] || f[i] < cost(j, i))
fr[i] = j, f[i] = cost(j, i);
}
// 记得调用整体函数前要先调用一次 upd(l, r)
void solve(int l, int r, int dep = 0) {
if (l >= r) return;
int mid = (l + r + 1) >> 1;
for (int i = fr[l]; i <= fr[r] && i < mid; i++) upd(i, mid);
if (mid < r) solve(l, mid, dep + 1);
for (int i = l; i < mid; i++) upd(i, r);
if (l < mid) solve(mid, r, dep + 1);
}
:::
这种算法相对于二分队列,虽然可以处理只支持移动访问的情况,但是却不如二分队列灵活,有些题目无法使用。比如下面的第三道例题。
例题
P4767 [IOI 2000] 邮局 加强版
:::info[题目链接]{open}
:::
我们在 dp 时可以放松限制,给一个区间的村庄钦定一个邮局,不要求每个状态下村庄的最小邮局都是给它钦定的邮局,因为这样不优,不会称为最终答案。
首先列出暴力 dp:设
其中
暴力转移复杂度是
:::success[代码]
#include <algorithm>
#include <iostream>
using namespace std;
constexpr int inf = 0x3f3f3f3f;
constexpr int N = 3000, M = 300;
constexpr int XN = N + 10, XM = M + 10;
int n, m, a[XN], w[XN][XN];
int f[XN], g[XN];
void solve(int l, int r, int L, int R) {
if (l > r) return;
int mid = (l + r) >> 1, pos = -1, res = 0;
for (int i = L; i <= R && i < mid; i++)
if (!~pos || res > g[i] + w[i + 1][mid])
pos = i, res = g[i] + w[i + 1][mid];
f[mid] = res;
solve(l, mid - 1, L, pos), solve(mid + 1, r, pos, R);
}
int main() {
ios::sync_with_stdio(false);
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> a[i];
sort(a + 1, a + n + 1);
for (int i = 1; i < n; i++)
w[i][i + 1] = a[i + 1] - a[i];
for (int len = 3; len <= n; len++)
for (int l = 1, r = len; r <= n; l++, r++)
w[l][r] = w[l + 1][r - 1] + a[r] - a[l];
f[0] = 0;
for (int i = 1; i <= n; i++) f[i] = inf;
for (int i = 1; i <= m; i++) {
for (int j = 0; j <= n; j++) g[j] = f[j];
solve(1, n, 0, n);
}
cout << f[n] << endl;
return 0;
}
:::
P5574 [CmdOI2019] 任务分配问题
:::info[题目链接]{open}
给定一个长度为
:::
可以简单地列出朴素动态规划:设
其中
暴力转移无法接受,考虑优化。观察到
证明是显然的:在
于是
:::success[代码]
#include <cstdint>
#include <iostream>
using namespace std;
using ll = int64_t;
template<int N>
struct BIT {
static constexpr int XN = N + 10;
int n, c[XN];
void update(int k, int x) { for (; k <= n; k += k & -k) c[k] += x; }
int query(int k) { int res = 0; for (; k; k -= k & -k) res += c[k]; return res; }
int query(int l, int r) { return query(r) - query(l - 1); }
};
constexpr ll inf = 0x3f3f3f3f;
constexpr int N = 2.5e4;
constexpr int XN = N + 10;
int n, m, a[XN];
BIT<N> bit;
int cl, cr; ll cans;
ll f[XN], g[XN];
void mov(int p, bool side, bool flag) {
if (side) {
if (flag) cans += bit.query(1, a[p] - 1), bit.update(a[p], 1);
else cans -= bit.query(1, a[p] - 1), bit.update(a[p], -1);
} else {
if (flag) cans += bit.query(a[p] + 1, n), bit.update(a[p], 1);
else cans -= bit.query(a[p] + 1, n), bit.update(a[p], -1);
}
}
int calc(int l, int r) {
while (cl > l) mov(--cl, 0, 1);
while (cr < r) mov(++cr, 1, 1);
while (cl < l) mov(cl++, 0, 0);
while (cr > r) mov(cr--, 1, 0);
return cans;
}
void solve(int l, int r, int L, int R) {
if (l > r) return;
int mid = (l + r) >> 1, pos = -1, res = 0;
for (int i = L; i <= R && i < mid; i++)
if (!~pos || res > g[i] + calc(i + 1, mid))
pos = i, res = g[i] + calc(i + 1, mid);
f[mid] = res;
solve(l, mid - 1, L, pos), solve(mid + 1, r, pos, R);
}
int main() {
ios::sync_with_stdio(false);
cin >> n >> m, bit.n = n;
for (int i = 1; i <= n; i++) cin >> a[i];
cl = 1, cr = 0;
for (int i = 0; i <= n; i++)
f[i] = i ? inf : 0;
for (int k = 1; k <= m; k++) {
for (int i = 1; i <= n; i++) g[i] = f[i];
solve(1, n, 0, n);
}
cout << f[n] << endl;
return 0;
}
:::
P10538 [APIO2024] 星际列车
:::info[题目链接]{open}
有
:::
考虑如何设计状态。状态需要包含两个信息:当前在哪个行星,以及当前的时刻。因为每条列车有固定的发车和到站时间,以及固定的起始点,因此我们可以用当前经过的最后一条列车来表示状态。
具体的,设
其中
观察
证:考虑在
于是如果在
于是把边按
注意:虽然这里
:::success[代码]
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>
using namespace std;
using vec = vector<int>;
namespace xuezy {
using ll = int64_t;
constexpr int inf = 0x3f3f3f3f;
constexpr int N = 1e5, N4 = N * 4, M = N * 40;
constexpr int XN = N + 10, XN4 = N4 + 10, XM = M + 10;
int n, m, c, len, cnt[XN]; ll val[XN];
struct Edge { int u, v, w, s, t, id; } ed[XN], ef[XN], eg[XN];
struct Data { int l, r, id; };
struct Queue {
vector<Data> q; int hh, tt;
inline bool em() { return hh == tt; }
inline Data& h() { return q[hh]; }
inline Data& t() { return q[tt - 1]; }
inline void pst(const Data& val) { q[tt++] = val; }
inline void ppt() { tt--; }
inline void pph() { hh++; }
} qq[XN];
ll f[XN];
void uniq(vec& a_in, vec& b_in, vec& l_in, vec& r_in) {
static int b[XN4]; len = 0;
b[++len] = 0, b[++len] = inf;
for (int i = 0; i < m; i++)
b[++len] = a_in[i], b[++len] = b_in[i];
for (int i = 0; i < c; i++)
b[++len] = l_in[i], b[++len] = r_in[i];
sort(b + 1, b + len + 1), len = unique(b + 1, b + len + 1) - b - 1;
for (int i = 0; i < m; i++)
a_in[i] = lower_bound(b + 1, b + len + 1, a_in[i]) - b,
b_in[i] = lower_bound(b + 1, b + len + 1, b_in[i]) - b;
for (int i = 0; i < c; i++)
l_in[i] = lower_bound(b + 1, b + len + 1, l_in[i]) - b,
r_in[i] = lower_bound(b + 1, b + len + 1, r_in[i]) - b;
}
struct SegmentTree {
#define mid ((l + r) >> 1)
int rt[XN4], idx;
struct Pos { int x, y; } p[XN];
struct Node { int ls, rs, sum; } e[XM];
vector<int> pv[XN4];
void pushup(int u) { e[u].sum = e[e[u].ls].sum + e[e[u].rs].sum; }
int update(int v, int pos, int val, int l = 1, int r = len) {
int u = ++idx; assert(u < M); e[u] = e[v];
if (l == r) return e[u].sum += val, u;
if (pos <= mid) e[u].ls = update(e[v].ls, pos, val, l, mid);
else e[u].rs = update(e[v].rs, pos, val, mid + 1, r);
return pushup(u), u;
}
int _query(int u, int v, int ql, int qr, int l = 1, int r = len) {
if (ql <= l && r <= qr) return e[v].sum - e[u].sum;
if (qr <= mid) return _query(e[u].ls, e[v].ls, ql, qr, l, mid);
if (ql > mid) return _query(e[u].rs, e[v].rs, ql, qr, mid + 1, r);
return _query(e[u].ls, e[v].ls, ql, qr, l, mid)
+ _query(e[u].rs, e[v].rs, ql, qr, mid + 1, r);
}
void init() {
idx = 0, rt[0] = 0;
for (int i = 1; i <= c; i++)
pv[p[i].x].push_back(p[i].y);
for (int i = 1; i <= len; i++) {
rt[i] = rt[i - 1];
for (int j : pv[i])
rt[i] = update(rt[i], j, 1);
}
}
inline int query(int x1, int x2, int y1, int y2) {
if (x1 > x2 || y1 > y2) return 0;
return _query(rt[x1 - 1], rt[x2], y1, y2);
}
inline int query(int l, int r) {
return query(l, r, l, r);
}
#undef mid
} seg;
inline ll cost(int j, int i, int u) {
return f[j] + val[u] * seg.query(ed[j].t + 1, i - 1);
}
void movupd(const Edge& ed) {
if (f[ed.id] == inf) return;
int tar = ed.v, tim = ed.t; Queue& q = qq[tar];
if (!q.em() && q.h().r < tim) q.pph();
if (!q.em()) q.h().l = tim;
while (!q.em() && cost(q.t().id, q.t().l, tar) >= cost(ed.id, q.t().l, tar)) q.ppt();
if (q.em()) q.pst(Data{tim, len, ed.id});
else if (cost(q.t().id, q.t().r, tar) >= cost(ed.id, q.t().r, tar)) {
int l = q.t().l, r = q.t().r;
while (l < r) {
int mid = (l + r) >> 1;
if (cost(q.t().id, mid, tar) >= cost(ed.id, mid, tar)) r = mid;
else l = mid + 1;
}
if ((q.t().r = l - 1) < q.t().l) q.ppt();
q.pst(Data{l, len, ed.id});
} else if (q.t().r < len)
q.pst(Data{q.t().r + 1, len, ed.id});
}
void movqry(const Edge& ed) {
int tar = ed.u, tim = ed.s; Queue& q = qq[tar];
while (!q.em() && q.h().r < tim) q.pph();
if (!ed.id) f[ed.id] = 0;
else if (q.em()) f[ed.id] = inf;
else f[ed.id] = cost(q.h().id, tim, tar) + ed.w;
}
ll solve(int n_in, int m_in, int w_in, vec t_in, vec x_in, vec y_in, vec a_in, vec b_in, vec c_in, vec l_in, vec r_in) {
n = n_in, m = m_in, c = w_in;
uniq(a_in, b_in, l_in, r_in);
for (int i = 1; i <= n; i++)
val[i] = t_in[i - 1];
for (int i = 1; i <= m; i++)
ed[i] = Edge{x_in[i - 1] + 1, y_in[i - 1] + 1, c_in[i - 1], a_in[i - 1], b_in[i - 1], i};
for (int i = 1; i <= c; i++)
seg.p[i] = SegmentTree::Pos{l_in[i - 1], r_in[i - 1]};
seg.init();
for (int i = 1; i <= n; i++) cnt[i] = 0;
ed[0] = Edge{1, 1, 0, 1, 1, 0}, cnt[1]++;
++m, ed[m] = Edge{n, n, 0, len, len, m}, cnt[n]++;
for (int i = 0; i <= m; i++)
ef[i] = eg[i] = ed[i], cnt[ed[i].v]++;
for (int i = 1; i <= n; i++) qq[i].q.resize(cnt[i]);
sort(ef + 0, ef + m + 1, [](const Edge& i, const Edge& j) { return i.t < j.t; });
sort(eg + 0, eg + m + 1, [](const Edge& i, const Edge& j) { return i.s < j.s; });
for (int i = 0, j = 1; i < m || j <= m; ) {
if (j > m || (i < m && ef[i].t <= eg[j].s))
movupd(ef[i++]);
else movqry(eg[j++]);
}
return f[m] == inf ? -1 : f[m];
}
};
long long solve(int N, int M, int W, vec T, vec X, vec Y, vec A, vec B, vec C, vec L, vec R) {
return xuezy::solve(N, M, W, T, X, Y, A, B, C, L, R);
}
:::