题解:P9168 [省选联考 2023] 人员调度

· · 题解

一、前言

感觉好多题解都看不懂,讲的不太清楚,故写一个自认为比较清晰的题解。

二、题意描述

给定一个以 1 为根的树,还有 k 个人,每个人有一个属性 (u,v) ,表示这个人只可以放在 u 的子树内,且这个人的能力值是 v。请你安排一种安放人员的方法(可以有人不在树上),使得任意一个节点上最多只有一个人。最大化 \sum v 。有 m 次询问,每次加入 / 删除一个人。 请你对这些询问依次回答。

三、solution

一种想法

首先我们发现不管怎么样我们都需要求解出初始情况下的 ans ,故我们考虑怎么在初始情况下求解该问题。一个很直接的想法是把人员按照能力值从大到小排序,然后随意放置。但是这个贪心明显是假的,因为我们在随便放的时候不知道会不会影响后面的人(也就是从祖先放的时候有可能影响到 ta 的后代)。我们反过来想,类比树形 DP ,从叶子开始放,然后如果 u 的子树没有被填满就直接放初始时候在 u 的员工,否则找一个子树内能力值最小的一个替换掉(如果替换掉更优)。这个贪心正确性显然。我们考虑我们在做什么:实际上我们是维护了一个桶,设 a_i 是放在 i 上的员工的能力值,我们其实就是在 dfs 的过程中不断找当前子树的最小值然后更新。具体实现上,我们可以按照 dfs 序建立线段树,然后每次询问子树最小值,更新单点值。

那么加上加入和删除呢?加入我们就暴力加入,删除可以看作把一个员工的能力值将为 0 ,所以我们每次操作完后都进行一次询问。单次询问复杂度 O(n \log n) ,总体时间复杂度 O(n \log n+n \times m \log n) ,得分 48pts 。

code :

#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define LL __int128
#define pii pair < int , int >
#define mp(i, j) make_pair(i, j)
#define l(o) (o << 1)
#define r(o) (o << 1 | 1)
#define mid (l + r >> 1)

const int N = 2e5 + 5, M = 262200;
int sid, n, k, m, v[N], dfn[N], e[N], tot;
vector < int > G[N], per[N];
ll sum;

int mn[M], pl[M];

void pushup(int o) { 
    mn[o] = min(mn[l(o)], mn[r(o)]); 
    pl[o] = mn[l(o)] < mn[r(o)] ? pl[l(o)] : pl[r(o)];
}

void build(int o, int l, int r) {
    mn[o] = 0, pl[o] = l;
    if(l == r) return;
    build(l(o), l, mid), build(r(o), mid + 1, r);
    pushup(o);
}

void add(int o, int l, int r, int p, int x) {
    if(l == r) {
        mn[o] = max(mn[o], x);
        return;
    }
    if(p <= mid) add(l(o), l, mid, p, x);
    else add(r(o), mid + 1, r, p, x);
    pushup(o);
}

pii querymin(int o, int l, int r, int ql, int qr) {
    if(ql <= l && r <= qr) return mp(mn[o], pl[o]);
    pii ans = mp(1e9, 0);
    if(ql <= mid) ans = min(ans, querymin(l(o), l, mid, ql, qr));
    if(qr > mid) ans = min(ans, querymin(r(o), mid + 1, r, ql, qr));
    return ans;
}

void getdfn(int u) {
    dfn[u] = ++ tot;
    for(int to : G[u]) getdfn(to);
    e[u] = tot;
}

void solve(int u) {
    for(int to : G[u]) solve(to);
    for(int p : per[u]) {
        pii res = querymin(1, 1, n, dfn[u], e[u]);
        int x = res.first, y = res.second;
        if(v[p] <= x) continue;
        sum -= x, sum += v[p];
        add(1, 1, n, y, v[p]);
    }
}

int main() {
    cin >> sid >> n >> k >> m;
    for(int i = 2, fa; i <= n; i++) {
        cin >> fa;
        G[fa].push_back(i);
    }
    for(int i = 1, p; i <= k; i++) {
        cin >> p >> v[i];
        per[p].push_back(i);
    }
    getdfn(1), build(1, 1, n), solve(1);
    cout << sum << " ";
    for(int i = 1; i <= m; i++) {
        build(1, 1, n);
        int op, x, y;
        cin >> op >> x;
        if(op == 1) {
            cin >> y;
            v[++ k] = y;
            per[x].push_back(k);
        } else v[x] = 0;
        sum = 0, solve(1);
        cout << sum << " ";
    } 
    return 0;
}

正解

实在是想不到啊,看了题解才会。

我们发现上述算法的瓶颈就是后续的 m 个操作。依据我们的贪心,我们不得不从下到上考虑,而这样的约束让我们在只有一个人这样非常少的信息改变的时候束手无策。容易想到不断全局查询是一个非常蠢的做法,因为我们做了大量的重复或者说不必要的计算。

所以我们需要发掘更好的性质。考虑特殊性质:没有第二种操作。如果我们只需要加入,那么会不会有什么很好的性质呢?

我们约定没有被放置的员工被称为被淘汰。一个非常显然的性质:如果一个人被淘汰了,那么再加入一些人他也不可能复活。这可以通过 48 pts 的构造方法得证。也就是我们只需要考虑现在分配到树上的人的取舍即可。

我们现在要在 u 上新增一个员工。

我们模拟贪心的过程,考虑 u 的 offer 是被谁抢了,一种可能就是被 u 子树内一个能力值没有 u 大的人给替了,还有就是被上面的人本不应该抢他的人给抢了。所以对于这两种情况我们进行讨论。如果是被后代抢了我们就暴力替换,否则就把上面的那个人给扔到 ta 的原始节点上。于是乎我们就变成了一个递归的问题。什么时候停止呢?当一个节点 x 使得 x 中的 offer 没有外援的时候就没有办法向上扔了。这时候就只能找一个子树内最小的进行替换。画图发现这个点 x 就是距离 u 最近的满足没有外援的祖先。

所以我们发现其实最后对答案的新贡献就是把新人加进去了,然后把一个最小值扔出去了。那么很明显只需要最小化这个最小值即可。这时候我们发现 x 显然就是我们可以取到最小值的这个祖先。所以我们每次只需要询问距离 u 的最近的没有外援的 x 并且询问 x 的子树最小值即可。

我们定义 size_x 是 x 的子树大小, s_x 是 x 子树中 offer 的原始节点都是 x 子树内的数量,那么 x 就是距离 u 最近的 size_x = s_x 的节点。故我们维护 size_x - s_x 。每次减去最小值以后就把最小值所在原始节点到根的路径的维护值都加上 1 。然后 u 到根减去 1 。可以用树链剖分处理。

还没完,因为我们都是在 size_x = s_x 的节点上进行查询最小值,所以我们的子树最小值就是所有以 x 子树内的节点为初始节点的人员的最小能力值。我们在每个点上维护一个 multiset 即可。线段树维护最小值。

至于删除,我们使用线段树分治,把所有的上述操作反着操作一遍就可以撤销了。

注意当 x 不存在时需要特判。可以证明,当不存在 x 的时候一定可以通过某些操作使得 u 放置是 0 代价的。 query 函数的剪枝非常重要。

code:

#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define LL __int128
#define pii pair < int , int >
#define mp(i, j) make_pair(i, j)
#define size Size
#define l(o) (o << 1)
#define r(o) (o << 1 | 1)
#define mid (l + r >> 1)

const int N = 2e5 + 5, M = 4e5 + 5;
const int INF = 1e9;
int sid, n, m, k, las[N], p[N], v[N];
ll Ans[N];
int dfn[N], e[N], rk[N], tot, size[N], son[N], top[N], fa[N];
vector < int > G[N], vec[M];
multiset < int > st[N];

void dfs(int u) {
    size[u] = 1;
    for(int to : G[u]) {
        dfs(to);
        size[u] += size[to];
        if(size[to] >= size[son[u]]) son[u] = to;
    }
}

void dfs1(int u, int x) {
    top[u] = x;
    dfn[u] = ++ tot, rk[tot] = u;
    if(son[u]) dfs1(son[u], x);
    for(int to : G[u]) 
        if(to != son[u]) dfs1(to, to);
    e[u] = tot;
}

int sum[M], tag[M]; pii mn[M];
// sum = min(size - s)
// mn -> min(val), real position

void pushup(int o) {
    sum[o] = min(sum[l(o)], sum[r(o)]);
    mn[o] = min(mn[l(o)], mn[r(o)]);
}

void pushdown(int o) {
    if(!tag[o]) return;
    tag[l(o)] += tag[o], tag[r(o)] += tag[o];
    sum[l(o)] += tag[o], sum[r(o)] += tag[o];
    tag[o] = 0;
}

void build(int o, int l, int r) {
    if(l == r) {
        sum[o] = size[rk[l]], mn[o] = mp(INF, rk[l]); // ?
        return;
    }   
    build(l(o), l, mid), build(r(o), mid + 1, r);
    pushup(o);
}

void cover(int o, int l, int r, int pos, int x) { // cover min(val)
    if(l == r) {
        mn[o] = mp(x, rk[l]);
        return;
    }
    pushdown(o);
    if(pos <= mid) cover(l(o), l, mid, pos, x);
    else cover(r(o), mid + 1, r, pos, x);
    pushup(o);
}

pii querymin(int o, int l, int r, int ql, int qr) { // query min(val, real)
    if(ql <= l && r <= qr) return mn[o];
    pushdown(o);
    pii ans = mp(INF, 0);
    if(ql <= mid) ans = min(ans, querymin(l(o), l, mid, ql, qr));
    if(qr > mid) ans = min(ans, querymin(r(o), mid + 1, r, ql, qr));
    return ans;
}

void add(int o, int l, int r, int ql, int qr, int x) { // sum + x
    if(ql <= l && r <= qr) {
        sum[o] += x, tag[o] += x;
        return;
    }
    pushdown(o);
    if(ql <= mid) add(l(o), l, mid, ql, qr, x);
    if(qr > mid) add(r(o), mid + 1, r, ql, qr, x);
    pushup(o);
}

int query(int o, int l, int r, int ql, int qr) { // query min(sum)
    if(ql <= l && r <= qr) return sum[o];
    pushdown(o);
    int ans = INF;
    if(ql <= mid) ans = min(ans, query(l(o), l, mid, ql, qr));
    if(qr > mid) ans = min(ans, query(r(o), mid + 1, r, ql, qr));
    return ans;
}

int find0(int o, int l, int r, int ql, int qr) { // real position
    // right  
    if(l == r) return sum[o] == 0 ? rk[l] : 0;
    if(sum[o] > 0) return 0; 
    pushdown(o); 
    int tmp = (qr > mid ? find0(r(o), mid + 1, r, ql, qr) : 0 );
    if(tmp) return tmp;
    return ql <= mid ? find0(l(o), l, mid, ql, qr) : 0;
}

void addtree(int u, int x) {
    while(u) {
        add(1, 1, n, dfn[top[u]], dfn[u], x);
        u = fa[top[u]];
    }
}

int query0(int u) {
    while(u) { // not !query
        if(query(1, 1, n, dfn[top[u]], dfn[u])) u = fa[top[u]];
        else return find0(1, 1, n, dfn[top[u]], dfn[u]);
    }
    return 0;
}

ll ans = 0;

void addt(int o, int l, int r, int ql, int qr, int x) {
    if(ql <= l && r <= qr) {
        vec[o].push_back(x);
        return;
    }
    if(ql <= mid) addt(l(o), l, mid, ql, qr, x);
    if(qr > mid) addt(r(o), mid + 1, r, ql, qr, x);
}

void insert(int x, int y) {
    if(!x || !y) return;
    st[x].insert(y);
    addtree(x, -1);
    cover(1, 1, n, dfn[x], *st[x].begin());
}

void remove(int x, int y) {
    if(!x || !y) return;
    st[x].erase(st[x].lower_bound(y));
    addtree(x, 1);
    cover(1, 1, n, dfn[x], st[x].empty() ? INF : *st[x].begin());
}

struct node {
    int ux, uy, vx, vy;
};

void solve(int o, int l, int r) {
    vector < node > U;
    for(int u : vec[o]) {
        int x = query0(p[u]);
        if(!x) { // 特判 x = 0 
            U.push_back((node) { p[u], v[u], 0, 0 });
            ans += v[u], insert(p[u], v[u]);
            continue;
        }
        pii res = querymin(1, 1, n, dfn[x], e[x]);
        if(res.first >= v[u]) continue;
        U.push_back((node) { p[u], v[u], res.second, res.first });
        ans += v[u], ans -= res.first;
        insert(p[u], v[u]), remove(res.second, res.first);
    }
    if(l == r) Ans[l] = ans;
    else solve(l(o), l, mid), solve(r(o), mid + 1, r);
    reverse(U.begin(), U.end());
    for(node res : U) {
        ans -= res.uy, ans += res.vy;
        remove(res.ux, res.uy), insert(res.vx, res.vy);
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin >> sid >> n >> k >> m;
    for(int i = 2; i <= n; i++) {
        cin >> fa[i];
        G[fa[i]].push_back(i);
    }
    dfs(1), dfs1(1, 1);
    for(int i = 1; i <= k; i++) cin >> p[i] >> v[i];
    for(int i = 1; i <= m; i++) {
        int op, x;
        cin >> op >> x;
        if(op == 1) {
            k ++;
            p[k] = x, cin >> v[k];
            las[k] = i;
        } else {
            addt(1, 0, m, las[x], i - 1, x);
            las[x] = m + 1;
        }
    }
    for(int i = 1; i <= k; i++) 
        if(las[i] <= m) addt(1, 0, m, las[i], m, i);
    build(1, 1, n);
    solve(1, 0, m);
    for(int i = 0; i <= m; i++) cout << Ans[i] << " ";
    return 0;
}