题解:P1411 树

· · 题解

思路

树形动态规划(Tree DP)

状态定义:

定义 dp[u][k] 表示在以节点 u 为根的子树中删除若干边后,u 所在的连通块大小为 k 时,**除 u 所在连通块以外**其余所有连通块大小乘积的最大值(实际存储其自然对数,以避免数值过大)。u 所在的连通块之后还可能与祖先合并,因此其大小 k 的贡献暂不计入。

状态转移:

依次合并 u 的每个子节点 v。设合并前 u 所在连通块大小为 k,v 所在连通块大小为 j($1 \le j \le size_v$),对边 (u,v) 有两种选择:

断开边 (u,v):v 所在的连通块(大小为 j,$1 \le j \le size_v$)成为一个独立的连通块,u 所在连通块的大小 k 保持不变,贡献为 $dp[v][j] + \log j$(该块已闭合,计入其大小的对数)。

保留边 (u,v):将 v 所在的连通块(大小为 j)并入 u 的连通块,新的连通块大小为 k + j,贡献为 $dp[v][j]$(合并后的块尚未闭合,其大小的对数留待之后计入)。

初始条件:

对每个节点 u,在合并任何子节点之前,dp[u][1] = 0:此时 u 单独成块,除它之外没有其他连通块,乘积为 1,对数为 0。其余状态视为不可达(代码中用 $-10^{18}$ 表示)。叶子节点没有子节点,其状态即为该初值。

结果计算:

最终答案取 $\max_{1 \le k \le n} \left( dp[1][k] + \log k \right)$,即再计入根所在连通块大小的对数。转移时记录每个状态的来源,回溯得到最优方案中所有连通块的大小,最后用高精度乘法计算实际乘积(对数只用于比较大小)。

代码(请勿抄袭!)

#include <algorithm>
#include <cmath>
#include <cstring>
#include <iostream>
#include <tuple>
#include <utility>
#include <vector>
using namespace std;

// Node ids are used directly as array indices, so every id must be < MAXN.
// NOTE(review): presumably sized for the problem's limit on n; confirm.
constexpr int MAXN = 705;

int n;                            // number of nodes, read in main
std::vector<int> G[MAXN];         // undirected adjacency lists
std::vector<int> children[MAXN];  // neighbours other than the parent (filled by dfs_size)
int sz[MAXN];                     // subtree sizes (filled by dfs_size)
double log_j[MAXN];               // log_j[i] == log(i), precomputed in main

/// Minimal arbitrary-precision non-negative integer.
/// Digits are base 10 in little-endian order: digits[0] is the least
/// significant digit. Only what this program needs is provided:
/// construction, multiplication by a machine int, and printing.
struct BigInt {
    std::vector<int> digits;

    /// Builds the decimal representation of x.
    /// Precondition: x >= 0 (no sign is stored).
    BigInt(long long x = 0) {
        if (x == 0) {
            digits.push_back(0);
            return;
        }
        while (x) {
            digits.push_back(static_cast<int>(x % 10));
            x /= 10;
        }
    }

    /// Adopts a little-endian digit vector.
    /// High-order zero digits are trimmed so output() never prints leading
    /// zeros. An empty vector becomes the single digit 0.
    BigInt(const std::vector<int>& d) : digits(d) {
        while (digits.size() > 1 && digits.back() == 0) digits.pop_back();
        if (digits.empty()) digits.push_back(0);
    }

    /// Returns *this * x. Precondition: x >= 0.
    BigInt operator*(int x) const {
        if (x == 0) return BigInt(0);
        std::vector<int> res;
        res.reserve(digits.size() + 10);
        // 64-bit accumulator: digit * x overflows int once x exceeds ~2.4e8.
        long long carry = 0;
        for (int d : digits) {
            carry += static_cast<long long>(d) * x;
            res.push_back(static_cast<int>(carry % 10));
            carry /= 10;
        }
        while (carry) {
            res.push_back(static_cast<int>(carry % 10));
            carry /= 10;
        }
        return BigInt(res);
    }

    /// Writes the number in decimal to std::cout, followed by a newline.
    void output() const {
        if (digits.empty()) {
            std::cout << 0 << '\n';
            return;
        }
        for (auto it = digits.rbegin(); it != digits.rend(); ++it)
            std::cout << *it;
        std::cout << '\n';
    }
};

// dp[u][k]: largest sum of log(component size) over the closed components of
// u's subtree, given that u's (still open) component has k nodes.
// -1e18 marks an unreachable state. Filled by dfs_dp.
std::vector<double> dp[MAXN];
// choices[u][i][k]: transition (previous size of u's component, size of the
// i-th child's component, whether the edge to that child was cut) that
// produced state k after merging u's i-th child. Resized in main, filled by dfs_dp.
std::vector<std::vector<std::vector<std::tuple<int, int, bool>>>> choices;

/// Roots the tree at the initial u.
/// For every node x reached, sz[x] becomes the size of x's subtree and
/// children[x] lists x's neighbours other than its parent.
void dfs_size(int u, int parent) {
    sz[u] = 1;
    const std::vector<int>& adj = G[u];
    for (std::size_t i = 0; i < adj.size(); ++i) {
        const int next = adj[i];
        if (next != parent) {
            dfs_size(next, u);
            children[u].push_back(next);
            sz[u] += sz[next];
        }
    }
}

/// Tree knapsack over the subtree of u. Requires children[] and sz[] from
/// dfs_size, log_j[] from main, and choices resized to cover node ids.
///
/// Afterwards, for 1 <= k <= sz[u], dp[u][k] is the largest sum of
/// log(component size) over all components of u's subtree EXCEPT the one
/// containing u, given that u's component has exactly k nodes.
/// Unreachable states hold -1e18.
///
/// choices[u][i][k] stores the transition (k_old, j, cut) that produced
/// state k after merging the i-th child. get_components uses it to rebuild
/// the optimal partition.
void dfs_dp(int u) {
    const double kUnreachable = -1e18;
    dp[u] = std::vector<double>(2, kUnreachable);
    dp[u][1] = 0.0;        // u alone; no closed components yet
    int current_size = 1;  // nodes of u's subtree merged so far

    std::vector<std::vector<std::tuple<int, int, bool>>> choices_u;
    choices_u.reserve(children[u].size());

    for (int v : children[u]) {
        dfs_dp(v);

        const int limit = current_size + sz[v];
        std::vector<double> new_dp(limit + 1, kUnreachable);
        std::vector<std::tuple<int, int, bool>> new_choice(limit + 1, std::make_tuple(-1, -1, false));

        for (int k_old = 1; k_old <= current_size; k_old++) {
            const double base = dp[u][k_old];
            if (base < -1e17) continue;  // unreachable state
            for (int j = 1; j <= sz[v]; j++) {
                if (dp[v][j] < -1e17) continue;

                // Cut (u, v): v's component (size j) is closed, so its
                // log-size is counted now; u's component keeps size k_old.
                const double cut_val = base + dp[v][j] + log_j[j];
                if (cut_val > new_dp[k_old]) {
                    new_dp[k_old] = cut_val;
                    new_choice[k_old] = std::make_tuple(k_old, j, true);
                }

                // Keep (u, v): v's component joins u's; nothing new closes.
                const int new_state = k_old + j;
                const double not_cut_val = base + dp[v][j];
                if (not_cut_val > new_dp[new_state]) {
                    new_dp[new_state] = not_cut_val;
                    new_choice[new_state] = std::make_tuple(k_old, j, false);
                }
            }
        }

        // Move the freshly built tables instead of deep-copying them.
        choices_u.push_back(std::move(new_choice));
        dp[u] = std::move(new_dp);
        current_size = limit;
    }

    choices[u] = std::move(choices_u);
}

vector<int> get_components(int u, int k) {
    vector<int> comps;
    int num_children = children[u].size();
    int cur_state = k;
    for (int i = num_children-1; i >= 0; i--) {
        int v = children[u][i];
        tuple<int, int, bool> tr = choices[u][i][cur_state];
        int prev_state = get<0>(tr);
        int child_state = get<1>(tr);
        bool cut = get<2>(tr);
        vector<int> child_comps = get_components(v, child_state);
        if (cut) {
            comps.push_back(child_state);
        }
        comps.insert(comps.end(), child_comps.begin(), child_comps.end());
        cur_state = prev_state;
    }
    return comps;
}

int main() {
    cin >> n;
    for (int i = 1; i <= n; i++) {
        log_j[i] = log(i);
    }

    for (int i = 0; i < n-1; i++) {
        int u, v;
        cin >> u >> v;
        G[u].push_back(v);
        G[v].push_back(u);
    }

    dfs_size(1, -1);
    choices.resize(n+1);
    dfs_dp(1);

    double best_log = -1e18;
    int best_k = -1;
    for (int k = 1; k <= sz[1]; k++) {
        if (dp[1][k] < -1e17) continue;
        double candidate = dp[1][k] + log_j[k];
        if (candidate > best_log) {
            best_log = candidate;
            best_k = k;
        }
    }

    vector<int> comps = get_components(1, best_k);
    comps.push_back(best_k);

    BigInt product(1);
    for (int x : comps) {
        product = product * x;
    }
    product.output();

    return 0;
}