P13020 [GESP202506 八级] 遍历计数 换根 DP 题解

· · 题解

前排提示:本题解包含多种做法~

题意简述

原题传送门。给出一棵 n 个节点的无根树。求树的 dfs 序可能数之和(树根不定),答案对 10^9 取模。

思路解析

具体实现细节见代码注释。

换根 DP + 线段树

本人赛时做法,相当复杂。

首先,容易想到这是一道树形 DP。我们先求以任意一个节点为根时树的 dfs 序可能个数之和。我们在处理以节点 u 为根的子树内 dfs 序可能数之和时,先对 u 的子节点集合 Son_u 进行全排列,再递归计算每个子节点的可能数,于是有转移方程

dp_u=cnt_u!\times\prod_{v\in Son_u}dp_v

其中 cnt_u 表示 u 的子节点个数。

然而我们的根不止一个,所以我们进行换根 DP 即可。但是你会发现我们将根节点 u 换成 v 时,新形态的子节点 u 转移方程会变成

dp'_u=(cnt_u-1)!\times\prod_{(v'\in Son_u\land v'\neq v)\lor (v'==f_u)}dp_{v'}

父节点 v 的则是

dp'_v=(cnt_v+1)!\times\prod_{v'\in Son_v\lor v'==u}dp_{v'}

而我们的模数又是 10^9,对于 dp'_u 来说不方便直接使用逆元去除 dp_v 的影响。于是我们想到使用线段树来查询。

线段树也不难维护。我们先将节点重新编号为其 bfs 序,这样所有的查询就变得连续了。再维护出每个节点 u 的子节点中最小的编号 s_u 和最大的编号 t_u 作为查询的左右端点即可。

这种做法的总时间复杂度为 O(nlogn),可以通过此题。

换根 DP (无线段树)

可能有的聪明人已经注意到了,线段树的部分完全可以使用前后缀积来写(而赛时的我并没有……)。

时空复杂度 O(n)。个人觉得比数据结构难调……毕竟要动脑调试细节嘛~

数学方法

不过根据赛后大家单方向传授讨论的结果,我们还有代码更简洁,常数更优秀的 O(n) 做法~

观察我们最开始的转移方程

dp_u=cnt_u!\times\prod_{v\in Son_u}dp_v

我们手动递归一下可以惊奇地发现(Subtree_u 表示以 u 为根的子树内节点的集合)

dp_u=\prod_{v\in Subtree_u}cnt_v!

我们再用节点 u 的度 deg_u 来改写。

deg_u=\begin{cases}cnt_u&u==rt\\cnt_u-1&else\end{cases}

所以我们可以直接写出(Tree 表示树内所有节点)

dp_{rt}=deg_{rt}!\prod_{u\in Tree\land u\neq rt}(deg_u-1)!=deg_{rt}\prod_{u\in Tree}(deg_u-1)!

这样我们就可以快速求出以每个节点为根时答案的总和啦~

ans=\sum_{u\in Tree} deg_u\prod_{v\in Tree}(deg_v-1)!=(\sum_{u\in Tree} deg_u)\prod_{u\in Tree}(deg_u-1)!=2(n-1)\prod_{u\in Tree}(deg_u-1)!

具体来说,先预处理出阶乘数组,然后就可以 O(n) 求出结果啦,灰常简单~ 千万别忘了要在 n=1 时特判输出 1 就行。

赛时没多推推式子,一步之差,AC 时间差了一个小时……数据结构敲多了导致的~

代码示例

禁止 ctrl+C&V (代码已进行防无脑抄袭处理~)

注意取模的处理(为了方便的可以直接开long long)

换根 DP + 线段树

#include<bits/stdc++.h>
using namespace std;
constexpr int N = 1e5, mod = 1e9;
int n;
int fct[N]; // fct[i]表示i的阶乘
vector<int> g[N];
int f[N]; // f[i]表示i节点的父亲节点
int id[N]; // id[i]表示i节点的bfs序
int rv[N]; // rv[i]表示bfs序为i的节点
void bfs(const int s) {
    queue<int> q;
    q.push(s);
    int tim = 0;
    while(!q.empty()) {
        const int u = q.front(); q.pop();
        rv[id[u] = ++tim] = u;
        for(const int v : g[u]) if(v != f[u])
            f[v] = u, q.push(v);
    }
}
int s[N], t[N]; // [s[i],t[i]]表示i节点的子节点的bfs序区间
int cnt[N]; // cnt[i]表示i节点的子节点个数
int dp[N]; // dp[i]表示以i节点为根的子树内dfs序可能个数之和
void dfs1(const int u) {
    dp[u] = 1;
    for(const int v : g[u]) if(v != f[u]) {
        ++cnt[u];
        if(!s[u]) s[u] = id[v];
        t[u] = id[v];
        dfs1(v);
        dp[u] = dp[u] * dp[v] % mod;
    }
    dp[u] = dp[u] * fct[cnt[u]] % mod;
}
struct segment_tree {
    int tr[N]; // 求的是积
    #define li i << 1
    #define ri i << 1 | 1
    void build(const int i, const int ll, const int rr) {
        if(ll == rr) return void(tr[i] = dp[rv[ll]]); // 注意此处的细节!!!
        const int mid = (ll + rr) >> 1;
        build(li, ll, mid), build(ri, mid + 1, rr);
        tr[i] = tr[li] * tr[ri];
    }
    int query(const int i, const int ll, const int rr, const int l, const int r) {
        if(l <= ll && rr <= r) return tr[i];
        const int mid = (ll + rr) >> 1;
        if(r <= mid) return query(li, ll, mid, l, r);
        if(l > mid) return query(ri, mid + 1, rr, l, r);
        return query(li, ll, mid, l, r) * query(ri, mid + 1, rr, l, r);
    }
    #undef li
    #undef ri
} Tr;
// 查询u节点所有子节点的dp[]的乘积
int query(const int u) {
    if(!s[u]) return 1;
    return Tr.query(1, 1, n, s[u], t[u]);
}
// 查询u节点除了v节点外所有子节点的dp[]的乘积
int query(const int u, const int v) {
    // 注意特判,否则TLE!!!(为了方便的可以直接在segment_tree::query()函数中特判)
    if(id[v] == s[u]) return Tr.query(1, 1, n, s[u] + 1, t[u]);
    return Tr.query(1, 1, n, s[u], id[v] - 1) * Tr.query(1, 1, n, id[v] + 1, t[u]);
}
int ans;
void dfs2(const int u) {
    if((ans += dp[u]) >= mod) ans -= mod;
    for(const int v : g[u]) if(v != f[u]) {
        const int dpu = dp[u], dpv = dp[v]; // 记录当前的状态,方便回溯
        dp[u] = query(u, v) * dp[f[u]] * fct[--cnt[u]] % mod;
        dp[v] = query(v) * dp[u] * fct[++cnt[v]] % mod;
        dfs2(v);
        ++cnt[u], --cnt[v], dp[u] = dpu, dp[v] = dpv;
    }
}
int main() {
    ios::sync_with_stdio(false), cin.tie(nullptr), cout.tie(nullptr);
    cin >>n;
    fct[0] = 1;
    for(int i = 1; i <= n; ++i) fct[i] = fct[i - 1] * i % mod;
    for(int i = 2, u, v; i <= n; ++i) cin >>u >>v,
        g[u].push_back(v), g[v].push_back(u);
    bfs(1);
    dfs1(1);
    Tr.build(1, 1, n);
    dp[0] = 1; // 为了dfs2(1)时dp[f[1]](即dp[0])不出错
    dfs2(1);
    cout <<ans <<'\n';
    return 0;
}

换根 DP (无线段树)

#include<bits/stdc++.h>
using namespace std;
constexpr int N = 1e5, mod = 1e9;
int n;
int fct[N];
vector<int> g[N];
int f[N];
int cnt[N]; // cnt[i]表示i节点的子节点个数
int ff[N]; // ff[i]表示i节点所有子节点dp[]值的乘积 (新增
int dp[N]; // dp[i]表示以i节点为根的子树内dfs序可能个数之和
void dfs1(const int u) {
    ff[u] = 1;
    for(const int v : g[u]) if(v != f[u]) {
        f[v] = u, ++cnt[u], dfs1(v);
        ff[u] = ff[u] * dp[v] % mod;
    }
    dp[u] = ff[u] * fct[cnt[u]] % mod;
}
int suf[N]; // suf[i]表示i节点的后缀积
int ans;
void dfs2(const int u) { // 细节挺多的……特别是g[u]中父亲的影响!!!
    if((ans += dp[u]) >= mod) ans -= mod;
    if(!cnt[u]) return; // 是叶子结点就直接返回
    int lst = g[u].back() == f[u] ? g[u][g[u].size() - 2] : g[u].back(); // 后缀节点
    suf[lst] = 1;
    for(int i = g[u].size() - 2; i >= 0; --i) {
        const int v = g[u][i];
        if(v != f[u]) suf[v] = 1LL * dp[lst] * suf[lst] % mod, lst = v;
    }
    int pre = 1;
    for(const int v : g[u]) if(v != f[u]) {
        const int dpu = dp[u], dpv = dp[v];
        dp[u] = pre * suf[v] * dp[f[u]] * fct[--cnt[u]] % mod;
        dp[v] = ff[v] * dp[u] * fct[++cnt[v]] % mod;
        dfs2(v);
        ++cnt[u], --cnt[v], dp[u] = dpu, dp[v] = dpv;
        pre = pre * dp[v] % mod;
    }
}
int main() {
    ios::sync_with_stdio(false), cin.tie(nullptr), cout.tie(nullptr);
    cin >>n;
    fct[0] = 1;
    for(int i = 1; i <= n; ++i) fct[i] = 1LL * fct[i - 1] * i % mod;
    for(int i = 2, u, v; i <= n; ++i) cin >>u >>v,
        g[u].push_back(v), g[v].push_back(u);
    dfs1(1);
    dp[0] = 1; // 为了dfs2(1)时dp[f[1]](即dp[0])不出错
    dfs2(1);
    cout <<ans <<'\n';
    return 0;
}

数学方法

#include<bits/stdc++.h>
using namespace std;
constexpr int N = 1e5, mod = 1e9;
int n;
int fct[N];
int deg[N];
int main() {
    ios::sync_with_stdio(false), cin.tie(nullptr), cout.tie(nullptr);
    cin >>n;
    if(n == 1) return cout <<"1\n", 0; // 记得特判!!!
    fct[0] = 1;
    for(int i = 1; i <= n; ++i) fct[i] = fct[i - 1] * i % mod;
    for(int i = 2, u, v; i <= n; ++i) cin >>u >>v, ++deg[u], ++deg[v];
    int ans = 1;
    for(int i = 1; i <= n; ++i) ans = ans * fct[deg[i] - 1] % mod;
    cout <<2 * (n - 1) * ans % mod <<'\n';
    return 0;
}

后记

考前本以为能半小时切掉的题,硬控了我一个半小时……一开始感觉能评上位绿甚至蓝,现在再回看感觉只能是绿甚至黄。

本蒟蒻的第一篇题解!如有不好之处望各位大佬多多指正~

无论如何,祝各位 CSP2025 比赛顺利,RP++,斩获佳绩!!!待到金秋时分,我们赛场上不见不散~