P11233 [CSP-S 2024] 染色

· · 题解

P11233 [CSP-S 2024] 染色

题意简述

给定一段序列,对其进行二染色。记每一个位置 i 的上一个颜色相同的位置为 last_i,则当 a_i=a_{last_i} 时,会对答案产生 a_i 的贡献。求整个序列最大的答案为多少。

思路

step 1

一眼 dp。由于每个位置产生贡献的条件是上个颜色相同的位置,所以对于每个位置,我们只需要关注上个颜色相同的位置就可以。所以自然而然,我们设 f_{i, j} 为前 i 位颜色已经确定,并且与第 i 位颜色相同的上一个位置是 j 的最大答案。那么状态转移方程就很好想了:

位置 i 与位置 i- 1 同色:

f_{i,j}=f_{i-1,j}(+a_i)(a_i==a_{i-1}时)

位置 i 与位置 j 同色:

f_{i,i-1}=f_{i-1,j}(+a_i)(a_i==a_j时)

显然,这个转移枚举 i 是一个 n,枚举 j 是个 nO(n^2),而 i 可以滚动数组滚掉,可以拿到 50 tps。

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define endl '\n'
const int N = 200010;

ll a[N];
ll f[2][N];

signed main(){
    int t;
    cin >> t;
    while(t --){
        memset(f, 0, sizeof f);
        int n;
        cin >> n;
        for(int i = 1; i <= n; i ++){
            cin >> a[i];
        }
        ll maxx = 0;
        for(int i = 1; i <= n; i ++){
            for(int j = 1; j < i; j ++){
                ll mid = 0;
                if(a[i] == a[i - 1])mid = a[i];
                f[i & 1][j] = max(f[i & 1][j], f[(i - 1) & 1][j] + mid);
                mid = 0;
                if(a[j] == a[i])mid = a[j];
                f[i & 1][i - 1] = max(f[i & 1][i - 1], f[(i - 1) & 1][j] + mid);
                maxx = max(maxx, f[i & 1][j]);
            }
        }
        cout << maxx << endl;
    }
    return 0;
}

step 2

然后由于我们知道每个位置产生的贡献只跟他上个颜色相同位置的值有关,所以我们可以换个状态转移方程:记 f_{i,j} 为前 i 位已经确定,并且与第 i 位颜色相同的上一位的数值为 j 的最大答案。所以这个的状态转移方程就是:

位置 i 与位置 j 同色:

f_{i,a_{i-1}}=f_{i-1,j}(+a_i)(a_i == j时)

位置 i 与位置 i- 1 同色:

f_{i,j}=f_{i-1,j}(+a_i)(a_i == a_{i - 1}时)

那么我们就可以获得 65 的高分(特殊性质 a_i \le10)。


#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define endl '\n'
const int N = 1000010;

ll a[N]; ll f[2][N];

signed main(){ int t; cin >> t; while(t --){ memset(f, -0x3f, sizeof f); int n; cin >> n; ll max_a = 0; for(int i = 1; i <= n; i ++){ cin >> a[i]; max_a = max(max_a, a[i]); } ll maxx = 0; f[1][a[1]] = 0; for(int i = 2; i <= n; i ++){ f[i & 1][a[i]] = max(0ll, f[i & 1][a[i]]); for(int j = 1; j <= max_a; j ++){ ll mid = 0; if(a[i] == j)mid = a[i]; f[i & 1][a[i - 1]] = max(f[i & 1][a[i - 1]], f[(i - 1) & 1][j] + mid); mid = 0; if(a[i] == a[i - 1])mid = a[i]; f[i & 1][j] = max(f[i & 1][j], f[(i - 1) & 1][j] + mid); } for(int j = 1; j <= max_a; j ++)maxx = max(maxx, f[i & 1][j]);

    }
    cout << maxx << endl;
}
return 0;

}

### step 3
然后我们观察代码,每一次转移的时候,都是只有 $a_i = a_{i - 1}$ 需要改,剩下的都是从 $i-1$  直接转移过来。所以我们可以维护线段树,对于上面两个点单点修改,剩下的区间修改就可以AC了。
```cpp
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define endl '\n'
const int N = 200010, M = 1000010;
const ll  MAXN = 0x3f3f3f3f3f3f3f3f;

ll a[N];
// ll f[2][N];
struct node{
    int l, r;
    ll w, flag;
}tr[M * 4];

void up(int u){
    tr[u]. w = max(tr[u << 1]. w, tr[u << 1 | 1]. w);
    return;
}

void down(int u){
    if(tr[u]. flag){
        tr[u << 1]. w += tr[u]. flag;
        tr[u << 1 | 1]. w += tr[u]. flag;
        tr[u << 1]. flag += tr[u]. flag;
        tr[u << 1 | 1]. flag += tr[u]. flag;
        tr[u]. flag = 0;
    }
    return;
}

void build(int u, int l, int r){
    tr[u] = {l, r};
    if(l == r){
        tr[u]. w = -MAXN;
        return;
    }
    int mid = (l + r) >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    up(u);
    return;
}
void change_one(int u, int x, ll w){
    if(tr[u]. l == x && tr[u]. r == x){
        tr[u]. w = max(tr[u]. w, w);
        return;
    }
    down(u);
    int mid = (tr[u]. l + tr[u]. r) >> 1;
    if(x <= mid)change_one(u << 1, x, w);
    else change_one(u << 1 | 1, x, w);
    up(u);
    return;
}

ll query(int u, int x){
    if(tr[u]. l == x && tr[u]. r == x){
        ll mid = tr[u]. w;
        if(tr[u]. w < 0)tr[u]. w = 0;
        return mid;
    }
    ll ans = 0;
    down(u);
    int mid = (tr[u]. l + tr[u]. r) >> 1;
    if(x <= mid)ans = query(u << 1, x);
    else ans = query(u << 1 | 1, x);
    up(u);
    return ans;
}
inline int read() {

    int x = 0, f = 1;

    char c = getchar();

    if(c < '0' || c > '9') {

        if (c == '-') f = -1;

        c = getchar();

    }

    while (c >= '0' && c <= '9') {

        x = x * 10 + c - '0';

        c = getchar();

    }

    return x * f;

}
signed main(){
    ios::sync_with_stdio(0);
    cin. tie(0), cout. tie(0);
    int t;
    t = read();
    while(t --){
        int n;
        n = read();
        ll max_a = 0;
        for(int i = 1; i <= n; i ++){
            a[i] = read();
            max_a = max(max_a, a[i]);
        }
        build(1, 1, max_a);
        query(1, a[1]);
        ll maxx = 0;
        for(int i = 2; i <= n; i ++){
            ll mid = tr[1]. w;
            mid = max(mid, query(1, a[i]) + a[i]);
            if(a[i] == a[i - 1]){
                tr[1]. w += a[i];
                tr[1]. flag += a[i];
            }
            change_one(1, a[i - 1], mid);

        }
        cout << tr[1]. w << endl;
    }
    return 0;
}