P6240 好吃的题目

· · 题解

怎么全是猫树分治,回滚莫队来一发。

先考虑暴力 DP,也就是直接每个区间单独跑一遍 DP。

然后我们发现 [l, r] 的 DP 状态是从 [l, r - 1] 转移来的,进而稍加思考,发现也不是不能从 [l + 1, r] 转移(废话)。

所以考虑优化转移路径,利用曾经算出来过的 [l_i,r_i] 来算出 [l_j,r_j] 的答案。

又发现从 [l, r + 1][l - 1,r] 转移到 [l,r] 不太现实。

所以考虑用回滚莫队维护转移。

直接写就行,想明白后基本就是板子了,把指针转移换成 O(h) 的 DP 即可。

总复杂度 O(n\sqrt m h),比较劣。

code

// code by 樓影沫瞬_Hz17
#include <bits/stdc++.h>

using namespace std;

#define int uint
constexpr int N = 4e4 + 10, B = 80;

int n, m;

int v[N], w[N], pos[N];

struct Que {
    int l, r, m, id;
} ;
vector<Que> Q[N / B + 10];

static int dpv[210], tmpv[210];
static int L[N / B + 10], R[N / B + 10];
int ans[200000 + 10], mxm[N / B + 10];

signed main() {
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    cin >> n >> m;
    for(int i = 1; i <= n; i ++) cin >> w[i];
    for(int i = 1; i <= n; i ++) cin >> v[i];

    for(int i = 1; R[i - 1] != n; i ++) {
        L[i] = (i - 1) * B + 1, R[i] = min(n, i * B);
        for(int j = L[i]; j <= R[i]; j ++) pos[j] = i;
    } 

    for(int i = 1, l, r, M; i <= m; i ++) {
        cin >> l >> r >> M;
        if(pos[r] - pos[l] <= 1) {
            memset(dpv, 0, sizeof dpv);
            for(int i = l; i <= r; i ++) 
                for(int j = M; j >= w[i]; j --) 
                    if(dpv[j] < dpv[j - w[i]] + v[i]) dpv[j] = dpv[j - w[i]] + v[i];
            ans[i] = dpv[M];
        }
        else Q[pos[l]].emplace_back((Que){l, r, M, i}), mxm[pos[l]] = max(mxm[pos[l]], M);
    }

    for(int i = 1; R[i - 1] != n; i ++) {
        if(Q[i].empty()) continue;
        sort(Q[i].begin(), Q[i].end(), [](Que a, Que b) { return a.r < b.r ; });
    }

    for(int i = 1; R[i - 1] != n; i ++) {
        int r = R[i], l, M = mxm[i];
        memset(dpv, 0, sizeof dpv);
        for(Que q : Q[i]) {
            while(r < q.r) {
                r ++;
                for(int j = M; j >= w[r]; j --) 
                    if(dpv[j] < dpv[j - w[r]] + v[r]) dpv[j] = dpv[j - w[r]] + v[r];
            }
            memcpy(tmpv, dpv, (M + 1) * 4);
            l = L[i + 1];
            while(l > q.l) {
                l --;
                for(int j = M; j >= w[l]; j --) 
                    if(tmpv[j] < tmpv[j - w[l]] + v[l]) tmpv[j] = tmpv[j - w[l]] + v[l];
            }
            ans[q.id] = tmpv[q.m];
        }
    }

    for(int i = 1; i <= m; i ++) cout << ans[i] << '\n';
}