P11363 [NOIP2024] 树的遍历
xiaosi4081 · · 题解
树的形态与遍历顺序无关,所以先要固定遍历顺序。
原树只要先访问了上面的一条边再访问下面一条边,就会要先访问下面这条边对应子树,然后回到上面那条边,再访问上面这条边同级的其他边。如图:
遍历到一条边后有两种可能:
- 向子树里面跑再回溯往同级边跑
- 去同级边再回溯往子树里跑
不需要考虑遍历顺序,所以,我们先往子树里跑,再往同级边跑,一条边就可以看成边+子树,
接下来就是
对于任意一点,其邻边构成一条链,那么可能的根节点就是某两条边(一条边)或者往这两条边的方向走是根。这样一颗树就对应着可能是根的一条叶子到叶子的链(双射关系)。
于是对链计数,对于固定的一条链,设其上面的节点集合为
这个被计入答案必须满足链上存在至少一个给定起点边。答案取和,所以前一项可以提出来,对后一项进行计数。
然后就可以 dp 了。
问题的核心在于找到一棵树可能的根节点构成一条从叶子到叶子的链,然后直接对链计数。再核心一点,就是找到能表示这棵树的双射关系然后计数,这个关系可以泛一点。
#include<bits/stdc++.h>
#define int long long
#define mod 1000000007
using namespace std;
const int N = 1e5+5;
int c, T, n, k, head[N];
struct edge{
int v, w, nxt;
} e[N*2];
int ecnt, flg[N];
void adde(int u, int v, int i){
e[++ecnt] = {v, i, head[u]};
head[u] = ecnt;
}
int s[N][2], vl[N], du[N], res;
void dfs(int u, int fa){
s[u][0] = s[u][1] = 0;
for(int i = head[u], v; i; i = e[i].nxt){
v = e[i].v;
if(v == fa) continue;
dfs(v, u);
if(flg[e[i].w]) s[u][1] += s[v][0]+s[v][1];
else{
s[u][0] += s[v][0];
s[u][1] += s[v][1];
}
}
s[u][0] %= mod, s[u][1] %= mod;
if(du[u] > 1){
for(int i = head[u], v, w; i; i = e[i].nxt){
v = e[i].v, w = flg[e[i].w];
if(v == fa) continue;
int sv0 = s[v][0]*(1-w), sv1 = s[v][1]+w*s[v][0];
res += sv1*(s[u][0]-sv0)%mod*vl[u]%mod;
res += sv0*(s[u][1]-sv1)%mod*vl[u]%mod;
res += sv1*(s[u][1]-sv1)%mod*vl[u]%mod;
res %= mod;
}
}
s[u][0] = s[u][0]%mod*vl[u]%mod;
s[u][1] = s[u][1]%mod*vl[u]%mod;
if(du[u] == 1) s[u][0] = 1;
}
int qpow(int a, int b){
if(a == 0) return 1;
int res = 1;
while(b){
if(b&1) res = res*a%mod;
a = a*a%mod;
b >>= 1;
} return res;
}
int fac[N];
signed main(){
fac[0] = 1;
for(int i = 1; i < N; i++) fac[i] = fac[i-1]*i%mod;
cin >> c >> T;
while(T--){
cin >> n >> k, ecnt = 0, res = 0;
for(int i = 1; i <= n; i++) head[i] = du[i] = flg[i] = 0;
for(int i = 1; i < n; i++){
int u, v; cin >> u >> v;
adde(u, v, i), adde(v, u, i);
du[u]++, du[v]++;
}
for(int i = 1; i <= k; i++){
int x; cin >> x;
flg[x] = 1;
}
if(n == 2){
cout << "1\n";
continue;
}
for(int i = 1; i <= n; i++) vl[i] = qpow(du[i]-1, mod-2);
for(int i = 1; i <= n; i++){
if(du[i] != 1){
dfs(i, 0);
break;
}
} res *= qpow(2, mod-2); res %= mod;
for(int i = 1; i <= n; i++) res = res*fac[du[i]-1]%mod;
res = (res+mod)%mod;
cout << res << endl;
}
return 0;
}