考虑找到树的重心,这样子每个节点都是小于 \log n 的大小,这样子就可以继续递归,保证最后是 O(n\log n)。
首先,找出树的重心。
然后,分两种情况讨论:经过重心和不经过重心。
然后不经过重心的就继续分治就好了。
例如,b-e 就是经过 a,但是 d-e 就没有经过,对于这种情况,我们要把 c 作为子树然后再求。
找重心代码:
void GetRoot( int x , int fa ) {
f[ x ] = 0 , siz[ x ] = 1;
for( int i = head[ x ] ; i ; i = e[ i ].nxt ) {
int y = e[ i ].v;
if( vis[ y ] || y == fa ) continue; //vis是否当过根节点
GetRoot( y , x );
siz[ x ] += siz[ y ];
f[ x ] = max( f[ x ] , siz[ y ] ); //求最大子树大小
}
f[ x ] = max( f[ x ] , sum - siz[ x ] ); //剩下那一部分也要算上
if( f[ x ] < f[ root ] ) root = x; //寻找最小的
}
然后我们用树的重心跑出子树每个的深度值,放在 depth 数组里。
参考代码:
void GetDep( int x , int fa , int sum ) {
dep[ ++tot ] = sum;
for( int i = head[ x ] ; i ; i = e[ i ].nxt ) {
int y = e[ i ].v;
if( vis[ y ] || y == fa ) continue;
GetDep( y , x , sum + e[ i ].w );
}
}
然后将 depth 数组排序,我们就知道每个的距离了。
如果枚举每一个节点,那时间复杂度会增加。
所以要排序,然后通过双指针求出个数。
例如这个数组:
0 1 2 2 3
k=4
现在 l=1,r=5。
$a_2+a_5=4$,答案数加 $5-2=3$。
$a_3+a_5>4$,超过了,$r$ 减一。
$a_3+a_4=4$,答案数加 $4-3=1$。
然后就搞定了。
注意,如果两个不经过重心,如 $d-e$,路径 $d-c-a-c-e$ 还是 $\le k$ 的话,我们就要把这个答案减掉,等到子树重心为 $c$ 时才算入答案。
完整代码:
```.cpp
#include <bits/stdc++.h>
using namespace std;
const int _ = 4e4 + 5;
inline int Read() {
int x = 0 , f = 1;
char c = getchar();
for( ; c < '0' || c > '9' ; c = getchar() ) f ^= ( c == '-' );
for( ; c >= '0' && c <= '9' ; c = getchar() ) x = ( x << 3 ) + ( x << 1 ) + ( c ^ 48 );
return f ? x : -x;
}
struct Edge { int v , w , nxt ; } e[_*2];
int head[_] , ecnt;
void Add( int u , int v , int w ) { e[ ++ecnt ] = Edge{ v , w , head[ u ] } ; head[ u ] = ecnt ; }
int n , k , root , sum , tot , ans;
int f[_] , siz[_] , dep[_] , vis[_];
void GetRoot( int x , int fa ) {
f[ x ] = 0 , siz[ x ] = 1;
for( int i = head[ x ] ; i ; i = e[ i ].nxt ) {
int y = e[ i ].v;
if( vis[ y ] || y == fa ) continue;
GetRoot( y , x );
siz[ x ] += siz[ y ];
f[ x ] = max( f[ x ] , siz[ y ] );
}
f[ x ] = max( f[ x ] , sum - siz[ x ] );
if( f[ x ] < f[ root ] ) root = x;
}
void GetDep( int x , int fa , int l ) {
dep[ ++tot ] = l;
for( int i = head[ x ] ; i ; i = e[ i ].nxt ) {
int y = e[ i ].v;
if( vis[ y ] || y == fa ) continue;
GetDep( y , x , l + e[ i ].w );
}
}
int Calc( int x , int L , int ans = 0 ) {
tot = 0;
GetDep( x , 0 , L );
sort( dep + 1 , dep + 1 + tot );
int l = 1 , r = tot;
while( l < r ) { //排序双指针法求答案
if( dep[ l ] + dep[ r ] <= k ) {
ans += r - l;
l++;
} else r--;
}
return ans;
}
void Devide( int x ) { //分治
vis[ x ] = 1;
ans += Calc( x , 0 );
for( int i = head[ x ] ; i ; i = e[ i ].nxt ) {
int y = e[ i ].v;
if( vis[ y ] ) continue;
ans -= Calc( y , e[ i ].w ); //减去重复的,当前总共为e[i].w
root = 0 , sum = siz[ y ];
GetRoot( y , 0 );
Devide( root ); //继续分治
}
}
int main() {
n = Read();
sum = f[ 0 ] = n;
for( int i = 1 ; i < n ; i++ ) {
int u = Read() , v = Read() , w = Read();
Add( u , v , w ) , Add( v , u , w );
}
k = Read();
GetRoot( 1 , 0 );
Devide( root );
printf( "%d\n" , ans );
return 0;
}
```
练习:
[P3806](https://www.luogu.com.cn/problem/P3806)