Accelerated 01 Knapsack Algorithm and the Quadrangle Inequality

(This was meant for 7ue9ueue.github.io, but it seems I have lost access to that account, so I'm posting it here for now.)

References:

Capacitated Dynamic Programming: Faster Knapsack and Graph Algorithms https://arxiv.org/abs/1802.06440

SMAWK Algorithm https://noshi91.github.io/Library/algorithm/smawk.cpp

LARSCH Algorithm https://noshi91.github.io/Library/algorithm/larsch.cpp

I was revisiting my personal statement and read Professor Kyriakos Axiotis and Professor Christos Tzamos's paper again. I implemented their ideas on the monotonicity of the knapsack problem and am recording this as notes on what I learned. Basically, the SMAWK algorithm and the LARSCH algorithm provide an alternative to the traditional binary search method for dynamic programming problems that involve the quadrangle inequality.

Quadrangle Inequality

References:

Quadrangle Inequality trick for dynamic programs https://tryalgo.org/en/graphs/2022/11/03/optimal-search-tree/

Quadrangle Inequality Properties https://codeforces.com/blog/entry/86306

The function w(x,y) satisfies the quadrangle inequality if:

a\leq b \leq c \leq d \Longrightarrow w(a,c)+w(b,d) \leq w(a,d)+w(b,c)

If w(x,y) is viewed as a matrix A with A_{x,y} = w(x,y), the quadrangle inequality says that A is a Monge matrix. In competitive programming the term "Monge matrix" is less known and the name "quadrangle inequality" is used more often (at least where I study). In my opinion they are essentially the same thing under different names.

To decide whether w(x,y) satisfies the quadrangle inequality, one only has to check that, for all j < i,

w(j,i)+w(j+1,i+1) \leq w(j+1,i)+w(j,i+1)

The general inequality then follows by induction: any a \leq b \leq c \leq d case is a sum of these adjacent 2 \times 2 cases.
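As a quick sanity check of this claim, here is a small brute-force sketch. The helper names quadrangleFull and quadrangleAdjacent are hypothetical; both take the cost as a callable w(j, i) on indices 0..n-1 and only look at the region j \leq i used in the definition above.

```cpp
#include <cassert>
#include <functional>

using Cost = std::function<long long(int, int)>;

// Full definition: for all a <= b <= c <= d,
// w(a,c) + w(b,d) <= w(a,d) + w(b,c).  O(n^4), small n only.
bool quadrangleFull(const Cost& w, int n) {
    for (int a = 0; a < n; a++)
        for (int b = a; b < n; b++)
            for (int c = b; c < n; c++)
                for (int d = c; d < n; d++)
                    if (w(a, c) + w(b, d) > w(a, d) + w(b, c)) return false;
    return true;
}

// Adjacent test: only the 2x2 blocks with j + 1 <= i,
// w(j,i) + w(j+1,i+1) <= w(j+1,i) + w(j,i+1).
bool quadrangleAdjacent(const Cost& w, int n) {
    for (int j = 0; j + 1 < n; j++)
        for (int i = j + 1; i + 1 < n; i++)
            if (w(j, i) + w(j + 1, i + 1) > w(j + 1, i) + w(j, i + 1)) return false;
    return true;
}
```

For example, w(j,i) = (i-j)^2 passes both tests and -(i-j)^2 fails both; the two functions should always agree.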

Optimizing dynamic programs with the quadrangle inequality is a method discovered by Knuth and is sometimes known as Knuth's optimization. It applies to recurrences of the following forms:

f(i,j) = \min_{i \leq k < j}\left(f(i,k)+f(k+1,j)+w(i,j)\right) \tag{3}

f(i) = \min_j w(j,i) \tag{4}

f(i) = \min_{j<i} \left(f(j)+w(j,i)\right) \tag{5}

f(i) = \min_j \left(g(j)+w(j,i)\right) \tag{6}

Note: (3) additionally requires w(b,c) \leq w(a,d) for a \leq b \leq c \leq d (w does not decrease when the interval grows), as well as the quadrangle inequality. In most cases this is easy to check.

The original method, from Knuth, D. E., Optimum binary search trees, Acta Informatica, 1(1), pages 14–25, 1971, seems to account only for case (3). However, studying this optimization further, we can see that (4) is the general idea behind it, and every other case can be converted to case (4). (The paper costs €39.95, which I don't have, so I don't really know what is inside.)

The Optimization

Consider the dynamic program f

f(i) = \min_{j \leq i} w(j,i)

Define \text{opt}(i) as the largest j \leq i such that

f(i) = w(\text{opt}(i),i)

Theorem 1 (monotonicity of opt; I don't actually know its standard name)

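i_1 < i_2 \Longrightarrow \text{opt}(i_1) \leq \text{opt}(i_2)

Proof

Suppose, for contradiction, that for some i_1 < i_2 we have j_1 = \text{opt}(i_1) > \text{opt}(i_2) = j_2. Since the minimum defining f(i) is taken over j \leq i, this gives j_2 < j_1 \leq i_1 < i_2.

Then by the quadrangle inequality (with a = j_2, b = j_1, c = i_1, d = i_2),

w(j_1,i_2) + w(j_2,i_1) \leq w(j_1,i_1) + w(j_2,i_2)

However,

w(j_2,i_1) \geq w(\text{opt}(i_1),i_1) = w(j_1,i_1)

because j_1 minimizes w(\cdot,i_1), and

w(j_1,i_2) > w(\text{opt}(i_2),i_2) = w(j_2,i_2)

because j_1 > j_2 and j_2 is the largest minimizer of w(\cdot,i_2).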

Which means that:

w(j_1,i_2) + w(j_2,i_1) > w(j_1,i_1) + w(j_2,i_2)

Contradiction. \square

To obtain other forms of the optimization, just make a suitable substitution. If \max replaces \min, define w'(x,y) = -w(x,y), so that maximizing w is the same as minimizing w'; the proof is then similar. The knapsack section below gives an example of an application.

The Algorithm

There are two common algorithms for applying Theorem 1 (other than SMAWK and LARSCH):

Algorithm 1 Just binary search

The algorithm runs in O(n\log n) time. It is best suited to Case (6), because every g(j) must already be known before the recursion starts.

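Another way to understand this method is to compare it to merge sort: the recursion has the same shape, and on each level the candidate ranges [ql, qr] overlap only at their endpoints, so each level costs O(n).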

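//Reference: https://oi-wiki.org/dp/opt/quadrangle/
// Computes f[i] = min_{j < i} w(j, i) for every i in [l, r],
// given that opt(i) is known to lie in [ql, qr] for all of them (Theorem 1).
// Initial call for rows 1..n: DP(1, n, 0, n - 1).
int w(int j, int i); // cost function, assumed O(1)
int f[N];            // N: an array bound, assumed to be defined elsewhere

void DP(int l, int r, int ql, int qr) { //O(nlogn)
  int mid = (l+r)/2;
  int k = ql;                             // best candidate for row mid so far
  for (int j=ql;j<=min(qr,mid-1);j++)
    if (w(j,mid)<w(k,mid)) k = j;
  f[mid] = w(k,mid);
  if (l<mid) DP(l,mid-1,ql,k);            // rows before mid: opt in [ql, k]
  if (r>mid) DP(mid+1,r,k,qr);            // rows after mid: opt in [k, qr]
}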
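Below is a self-contained version of the same recursion that can be run and compared against the O(n^2) definition. The names dcRows, minRowsDC and minRowsBrute are hypothetical, and the cost is passed as a std::function rather than the global w of the DP snippet above; rows are i = 1..n and candidates are j < i.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <functional>
#include <vector>

using Cost = std::function<long long(int, int)>;

// Same recursion as DP() above: rows l..r, their opt assumed in [ql, qr].
void dcRows(const Cost& w, std::vector<long long>& f,
            int l, int r, int ql, int qr) {
    if (l > r) return;
    int mid = (l + r) / 2;
    int k = ql;
    for (int j = ql; j <= std::min(qr, mid - 1); j++)
        if (w(j, mid) < w(k, mid)) k = j;
    f[mid] = w(k, mid);
    dcRows(w, f, l, mid - 1, ql, k);  // rows < mid: opt in [ql, k]
    dcRows(w, f, mid + 1, r, k, qr);  // rows > mid: opt in [k, qr]
}

std::vector<long long> minRowsDC(const Cost& w, int n) {
    std::vector<long long> f(n + 1, 0);
    dcRows(w, f, 1, n, 0, n - 1);
    return f;
}

// O(n^2) reference: f[i] = min_{0 <= j < i} w(j, i).
std::vector<long long> minRowsBrute(const Cost& w, int n) {
    std::vector<long long> f(n + 1, 0);
    for (int i = 1; i <= n; i++) {
        f[i] = LLONG_MAX;
        for (int j = 0; j < i; j++) f[i] = std::min(f[i], w(j, i));
    }
    return f;
}
```

With w(j,i) = g(j) + (i-j)^2 for an arbitrary array g, both functions should return the same vector.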

Algorithm 2 Binary search using a queue

By Theorem 1, for each j the set of i with \text{opt}(i) = j is a contiguous segment, which we can denote [l_j,r_j]. Therefore, we can use a double-ended queue to maintain the candidates that may still be optimal for future i, each with its segment. When a new candidate is added, pop the segments at the back on which it is at least as good, then binary search for the point where it overtakes the last remaining one.

This method is not the focus of this article, since it is more complicated than the previous one. It is only needed when f(i) depends on earlier values of f itself (Case (5)), whereas the knapsack problem falls under Case (6).

Here is a Chinese-language source for this method: https://oi-wiki.org/dp/opt/quadrangle/
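A minimal sketch of the queue method for Case (5), assuming the hypothetical cost w(j,i) = (S_i - S_j)^2 + C with S nondecreasing (this cost satisfies the quadrangle inequality). Each deque entry (j, l, r) records that candidate j is currently best for positions l..r. By the quadrangle inequality, once a newer candidate ties or beats an older one at some position, it keeps doing so further right, which is what makes the pops and the binary search valid. This follows the idea described above rather than any particular reference implementation; queueDP and bruteDP are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <deque>
#include <vector>

struct Seg { int j, l, r; };  // candidate j is currently best on positions [l, r]

// f(0) = 0, f(i) = min_{j < i} f(j) + (S[i] - S[j])^2 + C, via the deque method.
std::vector<long long> queueDP(const std::vector<long long>& S, long long C) {
    int n = (int)S.size() - 1;
    std::vector<long long> f(n + 1, 0);
    auto h = [&](int j, int x) { long long d = S[x] - S[j]; return f[j] + d * d + C; };
    std::deque<Seg> q;
    if (n >= 1) q.push_back({0, 1, n});
    for (int i = 1; i <= n; i++) {
        f[i] = h(q.front().j, i);  // segments cover [i, n], so the front covers i
        if (++q.front().l > q.front().r) q.pop_front();
        // Candidate i competes for positions i+1..n.
        while (!q.empty() && h(i, q.back().l) <= h(q.back().j, q.back().l))
            q.pop_back();          // i is at least as good on the whole back segment
        if (q.empty()) {
            if (i < n) q.push_back({i, i + 1, n});
            continue;
        }
        // i loses at back().l: binary search the first position where it wins.
        int lo = q.back().l, hi = q.back().r + 1;
        while (hi - lo > 1) {
            int mid = (lo + hi) / 2;
            if (h(i, mid) <= h(q.back().j, mid)) hi = mid; else lo = mid;
        }
        q.back().r = hi - 1;
        if (hi <= n) q.push_back({i, hi, n});
    }
    return f;
}

// O(n^2) reference for the same recurrence.
std::vector<long long> bruteDP(const std::vector<long long>& S, long long C) {
    int n = (int)S.size() - 1;
    std::vector<long long> f(n + 1, 0);
    for (int i = 1; i <= n; i++) {
        f[i] = LLONG_MAX;
        for (int j = 0; j < i; j++) {
            long long d = S[i] - S[j];
            f[i] = std::min(f[i], f[j] + d * d + C);
        }
    }
    return f;
}
```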

The 01 Knapsack Problem

Problem Link https://loj.ac/p/6039

For the 01 knapsack problem, denote the value of the i-th object by v_i and its size by s_i. Group the objects by size, sort each group by value in descending order, and take prefix sums within each group.

Let p_c[k] denote the prefix sum for size c: the total value of the k most valuable objects of size c, with p_c[0] = 0. The DP then becomes a sequence of (\max,+) convolutions with the arrays p_c.
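The preprocessing step can be sketched as follows. prefixBySize and isConcave are hypothetical helpers, and objects are given as (size, value) pairs. The second helper checks the property the rest of the article relies on: because each group is sorted in descending order, the increments of p_c never increase.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <map>
#include <utility>
#include <vector>

// p[c][k] = total value of the k most valuable objects of size c (p[c][0] = 0).
std::map<int, std::vector<long long>>
prefixBySize(const std::vector<std::pair<int, long long>>& items) {
    std::map<int, std::vector<long long>> groups;
    for (const auto& [s, v] : items) groups[s].push_back(v);
    std::map<int, std::vector<long long>> p;
    for (auto& [s, vals] : groups) {
        std::sort(vals.begin(), vals.end(), std::greater<long long>());
        std::vector<long long> pre(vals.size() + 1, 0);
        for (size_t k = 0; k < vals.size(); k++) pre[k + 1] = pre[k] + vals[k];
        p[s] = pre;
    }
    return p;
}

// True when consecutive differences p[k] - p[k-1] never increase.
bool isConcave(const std::vector<long long>& p) {
    for (size_t k = 2; k < p.size(); k++)
        if (p[k] - p[k - 1] > p[k - 1] - p[k - 2]) return false;
    return true;
}
```

For instance, objects of size 2 with values 5, 9, 1 give p_2 = 0, 9, 14, 15.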

Define f_c(i) as the best total value using only objects of size at most c, with total size at most i (so f_0 \equiv 0). Then

f_c(i) = \max_{j \geq 0,\ jc \leq i} \left(f_{c-1}(i-jc)+p_c[j]\right) \tag{14}

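Within one transition c is a constant, and f_c(i) only reads f_{c-1} at indices congruent to i modulo c. Here j is also bounded by the number of objects of size c (alternatively, p_c can be padded with zero increments, as Codes 2 and 3 do).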
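For reference, here is the transition (14) evaluated directly for one size c, which is the baseline the optimizations speed up. The function name transitionNaive is hypothetical; fOld holds f_{c-1}(0..W) and p holds p_c[0..k] for the k objects of size c.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// fNew[i] = max over j with j*c <= i and j < p.size() of fOld[i - j*c] + p[j].
// Costs O(W * k) for a size with k objects.
std::vector<long long> transitionNaive(const std::vector<long long>& fOld,
                                       const std::vector<long long>& p, int c) {
    int W = (int)fOld.size() - 1;
    std::vector<long long> fNew(W + 1);
    for (int i = 0; i <= W; i++) {
        fNew[i] = fOld[i];  // j = 0: take no object of size c
        for (int j = 1; j < (int)p.size() && j * c <= i; j++)
            fNew[i] = std::max(fNew[i], fOld[i - j * c] + p[j]);
    }
    return fNew;
}
```

Running every size this way costs O(W) per object in total, which is the usual O(nW) 01 knapsack DP.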

Define

w(j,i) = p_c[i-j]

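Lemma 1 -w(j,i) satisfies the quadrangle inequality; equivalently, w(j,i) satisfies it with the inequality reversed, which is the form the \max recurrence needs.

w(j,i)+w(j+1,i+1) \geq w(j+1,i)+w(j,i+1) \\ p_c[i-j]+p_c[i-j] \geq p_c[i-j-1] + p_c[i-j+1] \\ p_c[i-j]-p_c[i-j-1]\geq p_c[i-j+1]-p_c[i-j]\\

The LHS is the value with rank i-j among the objects of size c, and the RHS is the value with rank i-j+1. Since the values are sorted in descending order, the LHS is at least the RHS. \square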
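For the \max recurrence, what matters is that w(j,i) = p[i-j] satisfies the quadrangle inequality with the inequality reversed (equivalently, -w satisfies it). This can be spot-checked on concrete prefix sums; reversedQuadrangle is a hypothetical helper that tests the adjacent form wherever all four indices of p are valid.

```cpp
#include <cassert>
#include <vector>

// With w(j, i) = p[i - j], checks w(j,i) + w(j+1,i+1) >= w(j+1,i) + w(j,i+1).
bool reversedQuadrangle(const std::vector<long long>& p) {
    int K = (int)p.size() - 1;  // valid differences i - j are 0..K
    auto w = [&](int j, int i) { return p[i - j]; };
    for (int j = 0; j <= K; j++)
        for (int i = j + 1; i + 1 - j <= K; i++)
            if (w(j, i) + w(j + 1, i + 1) < w(j + 1, i) + w(j, i + 1)) return false;
    return true;
}
```

Prefix sums of a descending list such as 9, 5, 1 pass, while prefix sums of the ascending list 1, 5, 9 fail, which is why each group has to be sorted in descending order.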

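Rewriting (14) per residue class: fix r with 0 \leq r < c, write i = r + xc, and substitute y = x - j, so that j = x - y objects of size c are taken. Then

f_c(r+xc) = \max_{0 \leq y \leq x} \left(f_{c-1}(r+yc)+p_c[x-y]\right) \\ f_c(r+xc) = \max_{0 \leq y \leq x} \left(g(y)+w(y,x)\right), \quad g(y) = f_{c-1}(r+yc) \\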

This is Case (6) with \max in place of \min, solved separately on each residue class.
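The reindexing can be checked by evaluating the residue chains directly. The sketch below (hypothetical name transitionByResidue) builds g(y) = f_{c-1}(r + yc) for each residue r and takes the maximum over y, so it should agree with a direct evaluation of (14). The inner loop over y is exactly the row-optimum search that divide and conquer, SMAWK or LARSCH would replace.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// For each residue r: g(y) = fOld[r + y*c], w(y, x) = p[x - y], and
// fNew[r + x*c] = max_{y <= x, x - y < p.size()} g(y) + w(y, x).
std::vector<long long> transitionByResidue(const std::vector<long long>& fOld,
                                           const std::vector<long long>& p, int c) {
    int W = (int)fOld.size() - 1;
    std::vector<long long> fNew(W + 1);
    for (int r = 0; r < c && r <= W; r++) {
        std::vector<long long> g;  // one residue chain
        for (int i = r; i <= W; i += c) g.push_back(fOld[i]);
        int L = (int)g.size();
        for (int x = 0; x < L; x++) {
            long long best = g[x];  // y = x: w(x, x) = p[0] = 0
            for (int y = std::max(0, x - (int)p.size() + 1); y < x; y++)
                best = std::max(best, g[y] + p[x - y]);
            fNew[r + x * c] = best;
        }
    }
    return fNew;
}
```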

Lemma 2 w'(j,i) = g(j)+w(j,i) satisfies exactly the same inequality as w, because the g terms cancel:

\left[w'(j,i)+w'(j+1,i+1)\right]-\left[w'(j+1,i)+w'(j,i+1)\right] \\ = \left[w(j,i)+g(j)+w(j+1,i+1)+g(j+1)\right]-\left[w(j+1,i)+g(j+1)+w(j,i+1)+g(j)\right] \\ = \left[w(j,i)+w(j+1,i+1)\right]-\left[w(j+1,i)+w(j,i+1)\right] \\ \square
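A quick numerical illustration of this cancellation, using the hypothetical helper reversedQuadrangleWithG and an arbitrary (even non-monotone) g: the outcome depends only on p.

```cpp
#include <cassert>
#include <vector>

// w'(j, i) = g(j) + p[i - j]; checks the reversed adjacent inequality
// w'(j,i) + w'(j+1,i+1) >= w'(j+1,i) + w'(j,i+1). The g terms cancel.
bool reversedQuadrangleWithG(const std::vector<long long>& g,
                             const std::vector<long long>& p) {
    auto w2 = [&](int j, int i) { return g[j] + p[i - j]; };
    int n = (int)g.size();
    for (int j = 0; j + 1 < n; j++)
        for (int i = j + 1; i + 1 < n && i + 1 - j < (int)p.size(); i++)
            if (w2(j, i) + w2(j + 1, i + 1) < w2(j + 1, i) + w2(j, i + 1)) return false;
    return true;
}
```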

Cases (3), (4) and (5) can be handled similarly, and of course this is left as an exercise for the reader.

SMAWK algorithm and LARSCH algorithm

I don't fully understand the mechanism of these two algorithms yet, but SMAWK finds all the row optima of Case (6) in linear time, and LARSCH solves Case (5) in linear time. Therefore, the time complexity of the 01 knapsack problem can be reduced from O(nW) (the naive DP) to O(DW\log W) (divide and conquer) and then to O(DW) (SMAWK), where n is the number of objects, W is the knapsack capacity, and D is the number of distinct object sizes.
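A rough count behind these bounds, assuming every distinct size c is at most W (larger objects never fit) and splits the positions 0..W into c residue chains of about W/c + 1 entries each:

```latex
\begin{aligned}
\text{divide and conquer:}\quad & c \cdot O\!\left(\tfrac{W}{c}\log W + 1\right) = O(W\log W) \text{ per size}, && \text{total } O(DW\log W)\\
\text{SMAWK:}\quad & c \cdot O\!\left(\tfrac{W}{c} + 1\right) = O(W) \text{ per size}, && \text{total } O(DW)
\end{aligned}
```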

Appendix

Code 1 Binary Search Solution

#include "bits/stdc++.h"
#define int long long
using namespace std;

const int maxn = 5e4 + 10;
int n, m, lim;
int sz[maxn], f[maxn], g[maxn], t[maxn];
vector<int> s[maxn], a[maxn];

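// Divide and conquer over one residue chain t[1..c] for object size x:
// updates f at chain positions l..r, assuming their optimal source positions
// lie in [ql, qr]. a[x] holds the prefix sums p_x, sz[x] the number of objects.
void solve(int l, int r, int ql, int qr, int x)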
{
    if (l == r)
    {
        int sl = max(ql, l - sz[x]);
        int sr = min(qr, l - 1);
        for (int i = sl; i <= sr; i++)
        {
            f[t[l]] = max(f[t[l]], g[t[i]] + a[x][l - i]);
        }
        return;
    }
    int mid = (l + r) >> 1;
    int pos = mid;
    int sl = max(ql, mid - sz[x]);
    int sr = min(qr, mid - 1);
    for (int i = sl; i <= sr; i++)
    {
        if (f[t[mid]] <= g[t[i]] + a[x][mid - i])
        {
            f[t[mid]] = g[t[i]] + a[x][mid - i];
            pos = i;
        }
    }
    solve(l, mid, ql, pos, x);
    solve(mid + 1, r, pos, qr, x);
}
signed main()
{
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
    {
        int x;
        cin >> x;
        int y;
        cin >> y;
        lim = max(lim, x);
        s[x].push_back(y);
    }
    for (int i = 1; i <= lim; i++)
    {
        sort(s[i].begin(), s[i].end());
        reverse(s[i].begin(), s[i].end());
        a[i].resize(s[i].size() + 1);
        for (int x : s[i])
        {
            a[i][++sz[i]] = x;
            a[i][sz[i]] += a[i][sz[i] - 1];
        }
        for (int j = 0; j <= m; j++)
            g[j] = f[j];
        for (int j = 0; j < i; j++)
        {
            int c = 0;
            for (int k = j; k <= m; k += i)
                t[++c] = k;
            if (c >= 2)
                solve(2, c, 1, c, i);
        }
        vector<int>().swap(a[i]);
    }
    for (int i = 1; i <= m; i++)
        cout << f[i] << " ";
}

Code 2 LARSCH Algorithm

#include "bits/stdc++.h"
#define int long long
using namespace std;

// Larsch Algorithm, zero indexed, lower triangular matrix.
// https://noshi91.github.io/Library/algorithm/larsch.cpp
template <class T>
class larsch
{
    struct reduce_row;
    struct reduce_col;

    struct reduce_row
    {
        int n;
        std::function<T(int, int)> f;
        int cur_row;
        int state;
        std::unique_ptr<reduce_col> rec;

        reduce_row(int n_) : n(n_), f(), cur_row(0), state(0), rec()
        {
            const int m = n / 2;
            if (m != 0)
            {
                rec = std::make_unique<reduce_col>(m);
            }
        }

        void set_f(std::function<T(int, int)> f_)
        {
            f = f_;
            if (rec)
            {
                rec->set_f([&](int i, int j) -> T
                           { return f(2 * i + 1, j); });
            }
        }

        int get_argmin()
        {
            const int cur_row_ = cur_row;
            cur_row += 1;
            if (cur_row_ % 2 == 0)
            {
                const int prev_argmin = state;
                const int next_argmin = [&]()
                {
                    if (cur_row_ + 1 == n)
                    {
                        return n - 1;
                    }
                    else
                    {
                        return rec->get_argmin();
                    }
                }();
                state = next_argmin;
                int ret = prev_argmin;
                for (int j = prev_argmin + 1; j <= next_argmin; j += 1)
                {
                    if (f(cur_row_, ret) > f(cur_row_, j))
                    {
                        ret = j;
                    }
                }
                return ret;
            }
            else
            {
                if (f(cur_row_, state) <= f(cur_row_, cur_row_))
                {
                    return state;
                }
                else
                {
                    return cur_row_;
                }
            }
        }
    };

    struct reduce_col
    {
        int n;
        std::function<T(int, int)> f;
        int cur_row;
        std::vector<int> cols;
        reduce_row rec;

        reduce_col(int n_) : n(n_), f(), cur_row(0), cols(), rec(n) {}

        void set_f(std::function<T(int, int)> f_)
        {
            f = f_;
            rec.set_f([&](int i, int j) -> T
                      { return f(i, cols[j]); });
        }

        int get_argmin()
        {
            const int cur_row_ = cur_row;
            cur_row += 1;
            const auto cs = [&]() -> std::vector<int>
            {
                if (cur_row_ == 0)
                {
                    return {{0}};
                }
                else
                {
                    return {{2 * cur_row_ - 1, 2 * cur_row_}};
                }
            }();
            for (const int j : cs)
            {
                while ([&]()
                       {
            const int size = cols.size();
            return size != cur_row_ && f(size - 1, cols.back()) > f(size - 1, j); }())
                {
                    cols.pop_back();
                }
                if (cols.size() != n)
                {
                    cols.push_back(j);
                }
            }
            return cols[rec.get_argmin()];
        }
    };

    std::unique_ptr<reduce_row> base;

public:
    larsch(int n, std::function<T(int, int)> f)
        : base(std::make_unique<reduce_row>(n))
    {
        base->set_f(f);
    }

    int get_argmin() { return base->get_argmin(); }
};

const int maxn = 5e4 + 5;
const int maxk = 305;
int n, m, lim;
int sz[maxk], f[maxn], g[maxn], t[maxn];
vector<int> s[maxk], a[maxk];

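// Row maxima on one residue chain t[1..c] for object size x. LARSCH is zero
// indexed and finds row minima, so w is shifted by one and negated.
void solve(int c, int x)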
{
    const auto w = [&](const int ip, const int jp, const int x, const bool isLarsch) -> int
    {
        int i = ip;
        int j = jp;
        if (isLarsch)
            i++, j++;
        if (j >= i)
            return 0;
        else
            return g[t[j]] + a[x][i - j];
    };
    larsch<int> opt(c, [&](int i, int j)
                    { return -w(i, j, x, true); });
    for (int i = 1; i <= c; i++)
    {
        int j = opt.get_argmin();
        if (i == 1)
            continue;
        f[t[i]] = max(f[t[i]], w(i, j + 1, x, false));
    }
}

signed main()
{
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
    {
        int x;
        cin >> x;
        int y;
        cin >> y;
        lim = max(lim, x);
        s[x].push_back(y);
    }
    for (int i = 1; i <= lim; i++)
    {
        if (!s[i].size())
            continue;
        sort(s[i].begin(), s[i].end());
        reverse(s[i].begin(), s[i].end());
        a[i].resize(m / i + 5);
        a[i][0] = 0;
        for (int j = 1; j < a[i].size(); j++)
        {
            if (j - 1 < s[i].size())
                a[i][j] += s[i][j - 1];
            a[i][j] += a[i][j - 1];
        }
        for (int j = 0; j <= m; j++)
            g[j] = f[j];
        for (int j = 0; j < i; j++)
        {
            int c = 0;
            for (int k = j; k <= m; k += i)
                t[++c] = k;
            if (c >= 2)
                solve(c, i);
        }
        vector<int>().swap(a[i]);
    }
    for (int i = 1; i <= m; i++)
    {
        cout << f[i] << " ";
    }
}

Code 3 SMAWK Algorithm

#include "bits/stdc++.h"
#define int long long
using namespace std;

// SMAWK Algorithm, zero indexed, any matrix.
// https://noshi91.github.io/Library/algorithm/smawk.cpp
template <class Select>
std::vector<int> smawk(const int row_size, const int col_size, const Select &select)
{
    using vec = vector<int>;
    const function<vec(const vec &, const vec &)> solve =
        [&](const vec &row, const vec &col) -> vec
    {
        const int n = row.size();
        if (n == 0)
            return {};
        vec c2;
        for (const int i : col)
        {
            while (!c2.empty() && select(row[c2.size() - 1], c2.back(), i))
                c2.pop_back();
            if (c2.size() < n)
                c2.push_back(i);
        }
        vec r2;
        for (int i = 1; i < n; i += 2)
            r2.push_back(row[i]);
        const vec a2 = solve(r2, c2);
        vec ans(n);
        for (int i = 0; i != a2.size(); i++)
            ans[i * 2 + 1] = a2[i];
        int j = 0;
        for (int i = 0; i < n; i += 2)
        {
            ans[i] = c2[j];
            const int end = i + 1 == n ? c2.back() : ans[i + 1];
            while (c2[j] != end)
            {
                j++;
                if (select(row[i], ans[i], c2[j]))
                    ans[i] = c2[j];
            }
        }
        return ans;
    };
    vec row(row_size);
    iota(row.begin(), row.end(), 0);
    vec col(col_size);
    iota(col.begin(), col.end(), 0);
    return solve(row, col);
}

const int maxn = 5e4 + 5;
const int maxk = 305;
int n, m, lim;
int sz[maxk], f[maxn], g[maxn], t[maxn];
vector<int> s[maxk], a[maxk];

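// Row maxima on one residue chain t[1..c] for object size x via SMAWK
// (zero indexed, so w is shifted by one; select forbids columns k >= row i).
void solve(int c, int x)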
{
    const auto w = [&](const int ip, const int jp, const int x, const bool SMAWK) -> int
    {
        int i = ip;
        int j = jp;
        if (SMAWK)
            i++, j++;
        if (j >= i)
            return 0;
        else
            return g[t[j]] + a[x][i - j];
    };
    const auto select = [&](const int i, const int j, const int k)
    {
        if (i <= k)
            return false;
        return w(i, j, x, true) <= w(i, k, x, true);
    };
    const vector<int> amax = smawk(c, c, select);
    for (int i = 2; i <= c; i++)
    {
        f[t[i]] = max(f[t[i]], w(i, amax[i - 1] + 1, x, false));
    }
}

signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
    {
        int x;
        cin >> x;
        int y;
        cin >> y;
        lim = max(lim, x);
        s[x].push_back(y);
    }
    for (int i = 1; i <= lim; i++)
    {

        if (!s[i].size())
            continue;
        sort(s[i].begin(), s[i].end());
        reverse(s[i].begin(), s[i].end());

        a[i].resize(m / i + 5);
        a[i][0] = 0;

        for (int j = 1; j < a[i].size(); j++)
        {
            if (j - 1 < s[i].size())
                a[i][j] += s[i][j - 1];
            a[i][j] += a[i][j - 1];
        }

        for (int j = 0; j <= m; j++)
            g[j] = f[j];

        for (int j = 0; j < i; j++)
        {
            int c = 0;
            for (int k = j; k <= m; k += i)
                t[++c] = k;
            if (c >= 2)
                solve(c, i);
        }

        vector<int>().swap(a[i]);
    }
    for (int i = 1; i <= m; i++)
        cout << f[i] << " ";
}