Accelerated 01 Knapsack Algorithm and the Quadrangle Inequality
I seem to have lost access to the 7ue9ueue.github.io account... so I'm posting this here for now...
Reference:
Capacitated Dynamic Programming: Faster Knapsack and Graph Algorithms https://arxiv.org/abs/1802.06440
SMAWK Algorithm https://noshi91.github.io/Library/algorithm/smawk.cpp
LARSCH Algorithm https://noshi91.github.io/Library/algorithm/larsch.cpp
While revisiting my personal statement, I reread the paper by Professor Kyriakos Axiotis and Professor Christos Tzamos. I implemented their ideas on the monotonicity of the knapsack problem, and these are my notes on what I learned. In short, the SMAWK and LARSCH algorithms provide an alternative to the traditional binary search method for dynamic programming problems that involve the quadrangle inequality.
Quadrangle Inequality
Reference:
Quadrangle Inequality trick for dynamic programs https://tryalgo.org/en/graphs/2022/11/03/optimal-search-tree/
Quadrangle Inequality Properties https://codeforces.com/blog/entry/86306
The function
If the function
To decide whether
Which can be proved by induction.
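As a sanity check, the inequality can be tested numerically on a small cost function. The sketch below assumes the minimization form $w(a,c) + w(b,d) \leq w(a,d) + w(b,c)$ for $a \leq b \leq c \leq d$, and compares it with the adjacent-index condition $w(i,j) + w(i+1,j+1) \leq w(i,j+1) + w(i+1,j)$. The names `Cost`, `satisfies_qi_full`, and `satisfies_qi_adjacent` are hypothetical helpers, not part of any library.

```cpp
#include <cassert>
#include <functional>

using Cost = std::function<long long(int, int)>;

// Full check: w(a,c) + w(b,d) <= w(a,d) + w(b,c) for all 0 <= a <= b <= c <= d < n.
bool satisfies_qi_full(int n, const Cost& w) {
    for (int a = 0; a < n; a++)
        for (int b = a; b < n; b++)
            for (int c = b; c < n; c++)
                for (int d = c; d < n; d++)
                    if (w(a, c) + w(b, d) > w(a, d) + w(b, c)) return false;
    return true;
}

// Local check on adjacent indices only (i < j):
// w(i,j) + w(i+1,j+1) <= w(i,j+1) + w(i+1,j).
bool satisfies_qi_adjacent(int n, const Cost& w) {
    for (int i = 0; i + 1 < n; i++)
        for (int j = i + 1; j + 1 < n; j++)
            if (w(i, j) + w(i + 1, j + 1) > w(i, j + 1) + w(i + 1, j)) return false;
    return true;
}
```

For example, the convex cost $w(i,j) = (j-i)^2$ passes both checks, while its negation fails both. The full check is $O(n^4)$, so it is only meant for small experiments.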
Optimizing dynamic programs with the quadrangle inequality is a method discovered by Knuth, sometimes known as Knuth's Optimization. Knuth's Optimization solves:
Note: (3) requires $w(b,c) \leq w(a,d)$ as well as the quadrangle inequality. In most cases this is easy to verify. The original method in Knuth, D. E., Optimum binary search trees, Acta Informatica, 1(1), pages 14–25, 1971, seems to account only for case (3). Studying the optimization further, however, shows that (4) is its general form, and every other case can be converted to case (4). I don't have 39,95 € so I don't really know what is inside this link.
The Optimization
For dynamic programming
Define
Theorem 1 (monotone theorem?) (I actually don't know the name)
Proof
Suppose for some
Then by the quadrangle inequality,
However,
Which means that:
Contradiction.
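Theorem 1 can also be observed empirically. The sketch below assumes the simple form $f(i) = \min_{0 \leq j < i} w(j, i)$, with $\text{opt}(i)$ defined as the smallest minimizing $j$. That form is my assumption, and `brute_force_opt` and `is_non_decreasing` are hypothetical helper names. It computes $\text{opt}$ by brute force so that its monotonicity can be checked directly.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Brute-force opt(i) = smallest j in [0, i) that minimizes w(j, i), for i = 1..n-1.
// opt[0] is unused and left at 0.
std::vector<int> brute_force_opt(int n, const std::function<long long(int, int)>& w) {
    std::vector<int> opt(n, 0);
    for (int i = 1; i < n; i++) {
        int best = 0;
        for (int j = 1; j < i; j++)
            if (w(j, i) < w(best, i)) best = j;  // strict '<' keeps the leftmost minimizer
        opt[i] = best;
    }
    return opt;
}

// True when v[0] <= v[1] <= ... <= v[n-1].
bool is_non_decreasing(const std::vector<int>& v) {
    for (std::size_t i = 1; i < v.size(); i++)
        if (v[i - 1] > v[i]) return false;
    return true;
}
```

With a cost such as $w(j,i) = g_j + (i-j)^2$, which satisfies the quadrangle inequality for any $g$, the resulting `opt` array should be non-decreasing.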
To obtain other forms of the optimization, just do a suitable substitution. If
The Algorithm
Apart from SMAWK and LARSCH, two algorithms are commonly used to apply Theorem 1:
Algorithm 1 Just binary search
- In order to find $f(l \dots r)$, we split the range at its midpoint into two parts, $f(l \dots mid-1)$ and $f(mid+1 \dots r)$.
- Why is $f(mid)$ missing? Because we can evaluate $f(mid)$ and $\text{opt}(mid)$ by brute force.
- After finding $f(mid)$ and $\text{opt}(mid)$, recurse to solve the left part and the right part.
The algorithm runs in $O(n \log n)$ time when $w$ can be evaluated in $O(1)$.
Another way to understand this method is to compare it to merge sort: the recursion has the same shape.
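// Reference: https://oi-wiki.org/dp/opt/quadrangle/
#include <algorithm>
#include <vector>
using namespace std;
vector<int> f;        // f[i] = min over j < i of w(j, i); sized by the caller
int w(int j, int i);  // cost of choosing j for i, O(1)
// Fills f[l..r], knowing that every opt(i) lies in [ql, qr]. O(n log n) overall.
void DP(int l, int r, int ql, int qr) {
    int mid = (l + r) / 2;
    int k = ql;  // best candidate for mid found so far
    for (int j = ql; j <= min(qr, mid - 1); j++)
        if (w(j, mid) < w(k, mid)) k = j;
    f[mid] = w(k, mid);
    // opt is monotone, so each half only needs part of [ql, qr].
    if (l < mid) DP(l, mid - 1, ql, k);
    if (r > mid) DP(mid + 1, r, k, qr);
}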
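To try the divide-and-conquer routine in isolation, a self-contained variant might look like the following. It assumes the form $f(i) = \min_{0 \leq j < i} w(j, i)$ and passes the cost as a `std::function`. The names `dc`, `solve_dc`, and `solve_brute` are hypothetical, and the brute-force version is only there for comparison.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>

using Cost = std::function<long long(int, int)>;

// Fills f[l..r], given that opt(i) lies in [ql, qr] for every i in [l, r].
void dc(int l, int r, int ql, int qr, const Cost& w, std::vector<long long>& f) {
    if (l > r) return;
    int mid = (l + r) / 2;
    int k = ql;  // leftmost best candidate for mid
    for (int j = ql + 1; j <= std::min(qr, mid - 1); j++)
        if (w(j, mid) < w(k, mid)) k = j;
    f[mid] = w(k, mid);
    dc(l, mid - 1, ql, k, w, f);   // left half: opt <= k
    dc(mid + 1, r, k, qr, w, f);   // right half: opt >= k
}

// f[i] = min over 0 <= j < i of w(j, i), for i = 1..n-1 (f[0] = 0 by convention).
std::vector<long long> solve_dc(int n, const Cost& w) {
    std::vector<long long> f(n, 0);
    dc(1, n - 1, 0, n - 2, w, f);
    return f;
}

// O(n^2) reference implementation of the same recurrence.
std::vector<long long> solve_brute(int n, const Cost& w) {
    std::vector<long long> f(n, 0);
    for (int i = 1; i < n; i++) {
        f[i] = w(0, i);
        for (int j = 1; j < i; j++) f[i] = std::min(f[i], w(j, i));
    }
    return f;
}
```

On any cost that satisfies the quadrangle inequality, `solve_dc` and `solve_brute` should return the same vector. On other costs the divide-and-conquer version may silently return wrong values, because the pruning of candidates is no longer justified.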
Algorithm 2 Binary search using a queue
For the set of
This method is not the focus of this article, as it is clearly more complicated than the previous one. It's only used when
Here is a Chinese-language source for this method: https://oi-wiki.org/dp/opt/quadrangle/
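For readers who cannot follow the linked source, here is a minimal sketch of one common formulation of the queue method. It assumes the recurrence $f(i) = \min_{0 \leq j < i} \left( f(j) + w(j, i) \right)$ with $f(0) = 0$ and a cost $w$ satisfying the quadrangle inequality. The names `Segment`, `solve_queue`, and `solve_brute_1d` are hypothetical, and the details may differ from the version on OI Wiki.

```cpp
#include <algorithm>
#include <cassert>
#include <deque>
#include <functional>
#include <vector>

using Cost = std::function<long long(int, int)>;

struct Segment { int j, l, r; };  // candidate j is currently the best choice for every i in [l, r]

std::vector<long long> solve_queue(int n, const Cost& w) {
    std::vector<long long> f(n + 1, 0);
    auto val = [&](int j, int i) { return f[j] + w(j, i); };
    std::deque<Segment> q;
    q.push_back({0, 1, n});
    for (int i = 1; i <= n; i++) {
        while (q.front().r < i) q.pop_front();  // drop segments that end before i
        f[i] = val(q.front().j, i);
        if (i == n) break;
        // Insert candidate i: remove segments on which i already wins at their first usable position.
        while (!q.empty()) {
            const Segment& b = q.back();
            int lo = std::max(b.l, i + 1);
            if (lo > b.r || val(i, lo) <= val(b.j, lo)) q.pop_back();
            else break;
        }
        if (q.empty()) { q.push_back({i, i + 1, n}); continue; }
        Segment& b = q.back();
        // Binary search the first position hi where i is at least as good as b.j.
        // Invariant: b.j is strictly better at lo; hi == b.r + 1 means "not inside b".
        int lo = std::max(b.l, i + 1), hi = b.r + 1;
        while (hi - lo > 1) {
            int mid = (lo + hi) / 2;
            if (val(i, mid) <= val(b.j, mid)) hi = mid; else lo = mid;
        }
        if (hi <= n) {  // candidate i takes over [hi, n]
            b.r = hi - 1;
            q.push_back({i, hi, n});
        }
    }
    return f;
}

// O(n^2) reference for the same recurrence.
std::vector<long long> solve_brute_1d(int n, const Cost& w) {
    std::vector<long long> f(n + 1, 0);
    for (int i = 1; i <= n; i++) {
        f[i] = f[0] + w(0, i);
        for (int j = 1; j < i; j++) f[i] = std::min(f[i], f[j] + w(j, i));
    }
    return f;
}
```

Each candidate is pushed and popped at most once and each insertion does one binary search, so this sketch runs in $O(n \log n)$ evaluations of $w$. Unlike Algorithm 1, it can be used when $f(i)$ depends on earlier values of $f$.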
The 01 knapsack Problem
Problem Link https://loj.ac/p/6039
For the 01 knapsack problem, define the value of the object as
Define the prefix sum array as
Define
Note that $c$ can be considered a constant.
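To make the setup concrete, here is a small brute-force sketch of the grouped transition. It is based on my reading of Code 1 in the Appendix, which updates `f[t[l]]` with `g[t[i]] + a[x][l - i]`: for each weight $x$, sort the values in descending order, take prefix sums $a$, and set $f(c) = \max_k \left( g(c - kx) + a_k \right)$. The names `Item`, `knapsack_plain`, and `knapsack_grouped` are hypothetical, and weights are assumed to be positive.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <utility>
#include <vector>

using Item = std::pair<int, long long>;  // (weight, value), weight >= 1

// Plain O(n * m) 0/1 knapsack: best[c] = max value with total weight at most c.
std::vector<long long> knapsack_plain(int m, const std::vector<Item>& items) {
    std::vector<long long> best(m + 1, 0);
    for (const Item& it : items)
        for (int c = m; c >= it.first; c--)
            best[c] = std::max(best[c], best[c - it.first] + it.second);
    return best;
}

// Same answer, processed one weight class at a time (no speed-up yet).
std::vector<long long> knapsack_grouped(int m, const std::vector<Item>& items) {
    std::map<int, std::vector<long long>> groups;
    for (const Item& it : items) {
        assert(it.first >= 1);
        groups[it.first].push_back(it.second);
    }
    std::vector<long long> f(m + 1, 0);
    for (auto& entry : groups) {
        const int x = entry.first;
        std::vector<long long>& vals = entry.second;
        std::sort(vals.begin(), vals.end(), std::greater<long long>());
        std::vector<long long> a(vals.size() + 1, 0);  // a[k] = sum of the k largest values
        for (std::size_t k = 0; k < vals.size(); k++) a[k + 1] = a[k] + vals[k];
        const std::vector<long long> g = f;  // DP row before this weight class
        for (int c = 0; c <= m; c++)
            for (int k = 1; k < (int)a.size() && k * x <= c; k++)
                f[c] = std::max(f[c], g[c - k * x] + a[k]);
    }
    return f;
}
```

Capacities $c$ and $c - kx$ always share the same residue modulo $x$, which is why the solutions in the Appendix process each residue class separately.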
Define
Lemma 1
LHS is the element with rank
Rewriting (14) gives:
This reduces to Case (6).
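The property behind this lemma, as I read it, is that prefix sums of values sorted in descending order are concave: the increment $a_{k+1} - a_k$ is the element of rank $k+1$, and ranks are non-increasing in value. For a cost of the form $w(j, i) = a_{i-j}$, the adjacent-index inequality $w(j,i) + w(j+1,i+1) \geq w(j,i+1) + w(j+1,i)$ becomes $2a_d \geq a_{d+1} + a_{d-1}$ with $d = i - j$, which is exactly concavity. A minimal check, with hypothetical helper names:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// a[k] = sum of the k largest values (a[0] = 0).
std::vector<long long> prefix_sums_desc(std::vector<long long> values) {
    std::sort(values.begin(), values.end(), std::greater<long long>());
    std::vector<long long> a(values.size() + 1, 0);
    for (std::size_t k = 0; k < values.size(); k++) a[k + 1] = a[k] + values[k];
    return a;
}

// Concavity: a[k+1] - a[k] <= a[k] - a[k-1] for every interior index k.
bool is_concave(const std::vector<long long>& a) {
    for (std::size_t k = 1; k + 1 < a.size(); k++)
        if (a[k + 1] - a[k] > a[k] - a[k - 1]) return false;
    return true;
}
```

For example, the values 4, 9, 1, 7, 7, 3 give the prefix sums 0, 9, 16, 23, 27, 30, 31, whose increments 9, 7, 7, 4, 3, 1 never increase.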
Lemma 2
Cases (3), (5), and (6) can be handled similarly; the details are left as an exercise for the reader.
SMAWK algorithm and LARSCH algorithm
I don't yet fully understand how these two algorithms work, but SMAWK solves Case (6) in linear time and LARSCH solves Case (5) in linear time. Therefore, the time complexity of 01 knapsack can be optimized from
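Even without understanding the internals, the `smawk` routine from Code 3 in the Appendix (taken from noshi91's library) can be tried as a black box. The sketch below repeats that routine with plain `int` indices and wraps it in a hypothetical helper, `row_argmin_smawk`, that finds a minimum in every row of a matrix. It assumes `select(i, j, k)` returns true when column `k` should replace column `j` in row `i`, which is how Code 3 appears to use it, and that the input matrix is totally monotone.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <numeric>
#include <vector>

// SMAWK as listed in Code 3, with plain int indices.
template <class Select>
std::vector<int> smawk(const int row_size, const int col_size, const Select& select) {
    using vec = std::vector<int>;
    const std::function<vec(const vec&, const vec&)> solve =
        [&](const vec& row, const vec& col) -> vec {
        const int n = (int)row.size();
        if (n == 0) return {};
        vec c2;  // columns that survive the reduce step (at most n of them)
        for (const int i : col) {
            while (!c2.empty() && select(row[c2.size() - 1], c2.back(), i)) c2.pop_back();
            if ((int)c2.size() < n) c2.push_back(i);
        }
        vec r2;  // odd rows are solved recursively
        for (int i = 1; i < n; i += 2) r2.push_back(row[i]);
        const vec a2 = solve(r2, c2);
        vec ans(n);
        for (int i = 0; i != (int)a2.size(); i++) ans[i * 2 + 1] = a2[i];
        int j = 0;
        for (int i = 0; i < n; i += 2) {  // even rows are filled in by a linear scan
            ans[i] = c2[j];
            const int end = i + 1 == n ? c2.back() : ans[i + 1];
            while (c2[j] != end) {
                j++;
                if (select(row[i], ans[i], c2[j])) ans[i] = c2[j];
            }
        }
        return ans;
    };
    vec row(row_size), col(col_size);
    std::iota(row.begin(), row.end(), 0);
    std::iota(col.begin(), col.end(), 0);
    return solve(row, col);
}

// Hypothetical wrapper: column index of a minimum in every row of a totally monotone matrix.
std::vector<int> row_argmin_smawk(const std::vector<std::vector<long long>>& A) {
    const int rows = (int)A.size();
    const int cols = rows == 0 ? 0 : (int)A[0].size();
    assert(rows == 0 || cols > 0);
    return smawk(rows, cols, [&](int i, int j, int k) { return A[i][j] > A[i][k]; });
}
```

A matrix such as $A_{ij} = (2i - j)^2$ is a convenient test input, since it satisfies the quadrangle inequality and its row minima are easy to verify by hand.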
Appendix
Code 1 Binary Search Solution
#include "bits/stdc++.h"
#define int long long
using namespace std;
const int maxn = 5e4 + 10;
int n, m, lim;
// f: current DP row, g: previous DP row, t: capacities of the current residue class
int sz[maxn], f[maxn], g[maxn], t[maxn];
// s[x]: values of the items of weight x, a[x]: their prefix sums in descending order
vector<int> s[maxn], a[maxn];
// Computes f[t[l..r]] for weight x, with the split point restricted to [ql, qr].
void solve(int l, int r, int ql, int qr, int x)
{
    if (l == r)
    {
        int sl = max(ql, l - sz[x]);
        int sr = min(qr, l - 1);
        for (int i = sl; i <= sr; i++)
        {
            f[t[l]] = max(f[t[l]], g[t[i]] + a[x][l - i]);
        }
        return;
    }
    int mid = (l + r) >> 1;
    int pos = mid;
    int sl = max(ql, mid - sz[x]);
    int sr = min(qr, mid - 1);
    // Brute-force the best split point for mid.
    for (int i = sl; i <= sr; i++)
    {
        if (f[t[mid]] <= g[t[i]] + a[x][mid - i])
        {
            f[t[mid]] = g[t[i]] + a[x][mid - i];
            pos = i;
        }
    }
    solve(l, mid, ql, pos, x);
    solve(mid + 1, r, pos, qr, x);
}
signed main()
{
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
    {
        int x;
        cin >> x;
        int y;
        cin >> y;
        lim = max(lim, x);
        s[x].push_back(y);
    }
    for (int i = 1; i <= lim; i++)
    {
        // Sort the values of weight i in descending order and take prefix sums.
        sort(s[i].begin(), s[i].end());
        reverse(s[i].begin(), s[i].end());
        a[i].resize(s[i].size() + 1);
        for (int x : s[i])
        {
            a[i][++sz[i]] = x;
            a[i][sz[i]] += a[i][sz[i] - 1];
        }
        for (int j = 0; j <= m; j++)
            g[j] = f[j];
        // Process each residue class modulo i separately.
        for (int j = 0; j < i; j++)
        {
            int c = 0;
            for (int k = j; k <= m; k += i)
                t[++c] = k;
            if (c >= 2)
                solve(2, c, 1, c, i);
        }
        vector<int>().swap(a[i]);
    }
    for (int i = 1; i <= m; i++)
        cout << f[i] << " ";
}
Code 2 LARSCH Algorithm
#include "bits/stdc++.h"
#define int long long
using namespace std;
// Larsch Algorithm, zero indexed, lower triangular matrix.
// https://noshi91.github.io/Library/algorithm/larsch.cpp
template <class T>
class larsch
{
    struct reduce_row;
    struct reduce_col;
    struct reduce_row
    {
        int n;
        std::function<T(int, int)> f;
        int cur_row;
        int state;
        std::unique_ptr<reduce_col> rec;
        reduce_row(int n_) : n(n_), f(), cur_row(0), state(0), rec()
        {
            const int m = n / 2;
            if (m != 0)
            {
                rec = std::make_unique<reduce_col>(m);
            }
        }
        void set_f(std::function<T(int, int)> f_)
        {
            f = f_;
            if (rec)
            {
                rec->set_f([&](int i, int j) -> T
                           { return f(2 * i + 1, j); });
            }
        }
        int get_argmin()
        {
            const int cur_row_ = cur_row;
            cur_row += 1;
            if (cur_row_ % 2 == 0)
            {
                const int prev_argmin = state;
                const int next_argmin = [&]()
                {
                    if (cur_row_ + 1 == n)
                    {
                        return n - 1;
                    }
                    else
                    {
                        return rec->get_argmin();
                    }
                }();
                state = next_argmin;
                int ret = prev_argmin;
                for (int j = prev_argmin + 1; j <= next_argmin; j += 1)
                {
                    if (f(cur_row_, ret) > f(cur_row_, j))
                    {
                        ret = j;
                    }
                }
                return ret;
            }
            else
            {
                if (f(cur_row_, state) <= f(cur_row_, cur_row_))
                {
                    return state;
                }
                else
                {
                    return cur_row_;
                }
            }
        }
    };
    struct reduce_col
    {
        int n;
        std::function<T(int, int)> f;
        int cur_row;
        std::vector<int> cols;
        reduce_row rec;
        reduce_col(int n_) : n(n_), f(), cur_row(0), cols(), rec(n) {}
        void set_f(std::function<T(int, int)> f_)
        {
            f = f_;
            rec.set_f([&](int i, int j) -> T
                      { return f(i, cols[j]); });
        }
        int get_argmin()
        {
            const int cur_row_ = cur_row;
            cur_row += 1;
            const auto cs = [&]() -> std::vector<int>
            {
                if (cur_row_ == 0)
                {
                    return {{0}};
                }
                else
                {
                    return {{2 * cur_row_ - 1, 2 * cur_row_}};
                }
            }();
            for (const int j : cs)
            {
                while ([&]()
                       {
                           const int size = cols.size();
                           return size != cur_row_ && f(size - 1, cols.back()) > f(size - 1, j);
                       }())
                {
                    cols.pop_back();
                }
                if (cols.size() != n)
                {
                    cols.push_back(j);
                }
            }
            return cols[rec.get_argmin()];
        }
    };
    std::unique_ptr<reduce_row> base;

public:
    larsch(int n, std::function<T(int, int)> f)
        : base(std::make_unique<reduce_row>(n))
    {
        base->set_f(f);
    }
    int get_argmin() { return base->get_argmin(); }
};
const int maxn = 5e4 + 5;
const int maxk = 305;
int n, m, lim;
// f: current DP row, g: previous DP row, t: capacities of the current residue class
int sz[maxk], f[maxn], g[maxn], t[maxn];
// s[x]: values of the items of weight x, a[x]: their prefix sums in descending order
vector<int> s[maxk], a[maxk];
// Updates f[t[1..c]] for weight x using LARSCH.
void solve(int c, int x)
{
    // Value of reaching t[i] from t[j] by taking i - j items of weight x.
    // LARSCH is zero-indexed, so its indices are shifted by one.
    const auto w = [&](const int ip, const int jp, const int x, const bool isLarsch) -> int
    {
        int i = ip;
        int j = jp;
        if (isLarsch)
            i++, j++;
        if (j >= i)
            return 0;
        else
            return g[t[j]] + a[x][i - j];
    };
    // LARSCH finds row minima, so the value is negated to obtain maxima.
    larsch<int> opt(c, [&](int i, int j)
                    { return -w(i, j, x, true); });
    for (int i = 1; i <= c; i++)
    {
        int j = opt.get_argmin();
        if (i == 1)
            continue;
        f[t[i]] = max(f[t[i]], w(i, j + 1, x, false));
    }
}
signed main()
{
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
    {
        int x;
        cin >> x;
        int y;
        cin >> y;
        lim = max(lim, x);
        s[x].push_back(y);
    }
    for (int i = 1; i <= lim; i++)
    {
        if (!s[i].size())
            continue;
        sort(s[i].begin(), s[i].end());
        reverse(s[i].begin(), s[i].end());
        // Prefix sums, padded with a flat tail so that any count up to m / i is valid.
        a[i].resize(m / i + 5);
        a[i][0] = 0;
        for (int j = 1; j < a[i].size(); j++)
        {
            if (j - 1 < s[i].size())
                a[i][j] += s[i][j - 1];
            a[i][j] += a[i][j - 1];
        }
        for (int j = 0; j <= m; j++)
            g[j] = f[j];
        // Process each residue class modulo i separately.
        for (int j = 0; j < i; j++)
        {
            int c = 0;
            for (int k = j; k <= m; k += i)
                t[++c] = k;
            if (c >= 2)
                solve(c, i);
        }
        vector<int>().swap(a[i]);
    }
    for (int i = 1; i <= m; i++)
    {
        cout << f[i] << " ";
    }
}
Code 3 SMAWK Algorithm
#include "bits/stdc++.h"
#define int long long
using namespace std;
// SMAWK Algorithm, zero indexed, any matrix.
// https://noshi91.github.io/Library/algorithm/smawk.cpp
template <class Select>
std::vector<int> smawk(const int row_size, const int col_size, const Select &select)
{
    using vec = vector<int>;
    const function<vec(const vec &, const vec &)> solve =
        [&](const vec &row, const vec &col) -> vec
    {
        const int n = row.size();
        if (n == 0)
            return {};
        vec c2;
        for (const int i : col)
        {
            while (!c2.empty() && select(row[c2.size() - 1], c2.back(), i))
                c2.pop_back();
            if (c2.size() < n)
                c2.push_back(i);
        }
        vec r2;
        for (int i = 1; i < n; i += 2)
            r2.push_back(row[i]);
        const vec a2 = solve(r2, c2);
        vec ans(n);
        for (int i = 0; i != a2.size(); i++)
            ans[i * 2 + 1] = a2[i];
        int j = 0;
        for (int i = 0; i < n; i += 2)
        {
            ans[i] = c2[j];
            const int end = i + 1 == n ? c2.back() : ans[i + 1];
            while (c2[j] != end)
            {
                j++;
                if (select(row[i], ans[i], c2[j]))
                    ans[i] = c2[j];
            }
        }
        return ans;
    };
    vec row(row_size);
    iota(row.begin(), row.end(), 0);
    vec col(col_size);
    iota(col.begin(), col.end(), 0);
    return solve(row, col);
}
const int maxn = 5e4 + 5;
const int maxk = 305;
int n, m, lim;
// f: current DP row, g: previous DP row, t: capacities of the current residue class
int sz[maxk], f[maxn], g[maxn], t[maxn];
// s[x]: values of the items of weight x, a[x]: their prefix sums in descending order
vector<int> s[maxk], a[maxk];
// Updates f[t[1..c]] for weight x using SMAWK.
void solve(int c, int x)
{
    // Value of reaching t[i] from t[j] by taking i - j items of weight x.
    // SMAWK is zero-indexed, so its indices are shifted by one.
    const auto w = [&](const int ip, const int jp, const int x, const bool SMAWK) -> int
    {
        int i = ip;
        int j = jp;
        if (SMAWK)
            i++, j++;
        if (j >= i)
            return 0;
        else
            return g[t[j]] + a[x][i - j];
    };
    // Returns true when column k should replace column j in row i (row maxima).
    const auto select = [&](const int i, const int j, const int k)
    {
        if (i <= k)
            return false;
        return w(i, j, x, true) <= w(i, k, x, true);
    };
    const vector<int> amax = smawk(c, c, select);
    for (int i = 2; i <= c; i++)
    {
        f[t[i]] = max(f[t[i]], w(i, amax[i - 1] + 1, x, false));
    }
}
signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
    {
        int x;
        cin >> x;
        int y;
        cin >> y;
        lim = max(lim, x);
        s[x].push_back(y);
    }
    for (int i = 1; i <= lim; i++)
    {
        if (!s[i].size())
            continue;
        sort(s[i].begin(), s[i].end());
        reverse(s[i].begin(), s[i].end());
        // Prefix sums, padded with a flat tail so that any count up to m / i is valid.
        a[i].resize(m / i + 5);
        a[i][0] = 0;
        for (int j = 1; j < a[i].size(); j++)
        {
            if (j - 1 < s[i].size())
                a[i][j] += s[i][j - 1];
            a[i][j] += a[i][j - 1];
        }
        for (int j = 0; j <= m; j++)
            g[j] = f[j];
        // Process each residue class modulo i separately.
        for (int j = 0; j < i; j++)
        {
            int c = 0;
            for (int k = j; k <= m; k += i)
                t[++c] = k;
            if (c >= 2)
                solve(c, i);
        }
        vector<int>().swap(a[i]);
    }
    for (int i = 1; i <= m; i++)
        cout << f[i] << " ";
}