P14637 题解

· · 题解

符号与约定:

考虑 x 点的权值会填什么,容易发现 x 一定是填 x \to 1 这条链上某个点的 \mathrm{mex}

这样就能写出 O(n^3) 暴力,设 dp[x, i, j] 表示 x 子树内 \mathrm{mex} = i,有 j 个点还没决定,子树最大和是多少。这 j 个没决定的点是留着,等以后填 \mathrm{mex} 用。

转移很简单,先取子树内的点的 \mathrm{mex} 最大值,然后把 j 加起来,最后看填几个在 x 上就行了。

考虑优化,可以把 \mathrm{mex} 取最大值的那个儿子特殊考虑,整体的结构类似树剖,x 的重儿子是 y,当且仅当 x\mathrm{mex} 是从 y 继承的。

于是假定现在剖分的方式已定,怎么算答案?对于一个点 x,看 x \to 1 这条链与哪条重链交集最大,贡献就是交集大小。

这样就可以设出新的 DP 状态:dp[x, i, j] 表示 x 子树内,现在 x 到根最长的重链长 i,目前 x 到链顶的重链长 j,子树内的点的最大贡献和。

转移就是:对于 x,枚举 i, j 和它的重儿子 y,则答案为 i + dp[y, \max(i, j + 1), j + 1] + \sum_{z \in son_x} [z \ne y] dp[z, i, 1]

稍经优化可以做到复杂度 O(n m^2),但还需要进一步优化。

d_x 表示 \max_{y \in \mathrm{subtree}_x} \mathrm{dep}_y - \mathrm{dep}_x。观察到当 i - j \ge d_x 时,dp[x, i, j] 一定是 \mathrm{size}_x \times i,因为这时 j 没有可能超过 i。于是第三维换成 i - j,这样就能写出一个比较有优化前途的代码(贴这里主要方便对照下面的转移):

:::info[O(nm^2) 的代码]

#include <bits/stdc++.h>

using i64 = long long;
using namespace std;

constexpr int N = 4E3 + 5, M = 55, NN = 365;

int n, m;
vector<int> adj[N];
vector<vector<vector<int>>> dp;

void dfs(int x) {
    for (auto y : adj[x]) {
        dfs(y);
    }

    if (adj[x].empty()) {
        for (int i = 1; i <= m; ++i) {
            for (int j = 0; j < i; ++j) {
                dp[x][i][j] = i;
            }
        }
        return ;
    }

    for (int i = 1; i <= m; ++i) {
        for (int j = 0; j < i; ++j) {
            int sum = 0;
            for (auto y : adj[x]) {
                sum += dp[y][i][i - 1];
            }
            for (auto y : adj[x]) {
                dp[x][i][j] = max(dp[x][i][j], dp[y][i + !j][max(0, j - 1)] - dp[y][i][i - 1] + sum);
            }
            dp[x][i][j] += i;
        }
    }
}

void solve() {
    cin >> n >> m;
    m++;

    if (n <= 360) {
        dp.assign(NN, vector(NN, vector<int>(NN, 0)));
    } else {
        dp.assign(N, vector(M, vector<int>(M, 0)));
    }

    for (int i = 1; i <= n; ++i) {
        adj[i].clear();
    }

    for (int i = 2; i <= n; ++i) {
        int f;
        cin >> f;
        adj[f].push_back(i);
    }

    dfs(1);

    cout << dp[1][1][0] << "\n";
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int t;
    cin >> t;

    while (t--) {
        solve();
    }

    return 0;
}

:::

由于第三维是 O(d_x) 的,尝试用长剖优化,设 x 的长儿子为 l_x。具体地,考虑转移:

总复杂度 O(nm)

:::info[O(nm) 的代码]

#include <bits/stdc++.h>

using i64 = long long;
using namespace std;

constexpr int N = 8005, M = 805;

int n, m;
int d[N], l[N], s[N], add[N][M];
int head[N], nxt[N], to[N], tot;
int it[N][M], pool[N * M], pp;

void pre(int x) {
    s[x] = 1;
    l[x] = 0;
    for (int i = head[x]; i; i = nxt[i]) {
        int y = to[i];
        pre(y);
        if (d[y] >= d[l[x]]) {
            l[x] = y;
        }
        s[x] += s[y];
    }
    d[x] = d[l[x]] + 1;
}

inline int& DP(int x, int i, int j) {
    return pool[it[x][i] + j];
}
inline int getDP(int x, int i, int j) {
    if (j >= d[x]) {
        return i * s[x];
    }
    if (i > m) {
        return 0;
    }
    return pool[it[x][i] + j] + add[x][i];
}

void dfs(int x) {
    for (int i = 1; i <= m; ++i) {
        add[x][i] = 0;
    }

    if (!l[x]) {
        for (int i = 1; i <= m; ++i) {
            DP(x, i, 0) = i;
        }
        return;
    }

    for (int i = 1; i <= m; ++i) {
        it[l[x]][i] = it[x][i] + 1;
    }
    dfs(l[x]);

    for (int i = 1; i <= m; ++i) {
        add[x][i] = add[l[x]][i];
        DP(x, i, 0) = getDP(l[x], i + 1, 0) - add[x][i];
    }

    int ls = pp;

    for (int i = head[x]; i; i = nxt[i]) {
        int y = to[i];
        if (y != l[x]) {
            for (int j = 1; j <= m; ++j) {
                it[y][j] = pp;
                pp += d[y];
            }
            memset(pool + it[y][1], 0, m * d[y] * 4);
            dfs(y);
        }
    }

    for (int i = 1; i <= m; ++i) {
        int S = 0;
        for (int j = head[x]; j; j = nxt[j]) {
            S += getDP(to[j], i, i - 1);
        }
        add[x][i] += S - getDP(l[x], i, i - 1);

        for (int j = head[x]; j; j = nxt[j]) {
            int y = to[j];
            if (y == l[x]) {
                continue;
            }
            for (int k = 0; k <= d[y]; ++k) {
                DP(x, i, k) = max(DP(x, i, k), getDP(y, i + !k, k - !!k) - getDP(y, i, i - 1) + S - add[x][i]);
            }
        }
        add[x][i] += i;
    }

    pp = ls;
}

void solve() {
    cin >> n >> m;
    m++;

    tot = 0;
    for (int i = 1; i <= n; ++i) {
        head[i] = 0;
    }

    for (int i = 2; i <= n; ++i) {
        int f;
        cin >> f;
        to[++tot] = i;
        nxt[tot] = head[f];
        head[f] = tot;
    }
    pre(1);

    pp = 0;
    for (int i = 1; i <= m; ++i) {
        it[1][i] = pp;
        pp += d[1];
    }
    memset(pool, 0, m * d[1] * 4);
    dfs(1);

    cout << getDP(1, 1, 0) << "\n";
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int t;
    cin >> t;

    while (t--) {
        solve();
    }

    return 0;
}

:::