点分治

· · 算法·理论

点分治,一种求树上路径长度的问题。时间复杂度为 O(n\log n)

前置知识:树的直径。

P4178

考虑暴力,直接枚举每两个端点。时间复杂度为 O(n^2)

但是大数据过不了,例如一条链就是无法过的。

考虑找到树的重心,这样子每个节点都是小于 \log n 的大小,这样子就可以继续递归,保证最后是 O(n\log n)

首先,找出树的重心。

然后,分两种情况讨论:经过重心和不经过重心。

然后不经过重心的就继续分治就好了。

例如,b-e 就是经过 a,但是 d-e 就没有经过,对于这种情况,我们要把 c 作为子树然后再求。

找重心代码:

void GetRoot( int x , int fa ) {
  f[ x ] = 0 , siz[ x ] = 1;
  for( int i = head[ x ] ; i ; i = e[ i ].nxt ) {
    int y = e[ i ].v;
    if( vis[ y ] || y == fa ) continue; //vis是否当过根节点
    GetRoot( y , x );
    siz[ x ] += siz[ y ];
    f[ x ] = max( f[ x ] , siz[ y ] );  //求最大子树大小
  }
  f[ x ] = max( f[ x ] , sum - siz[ x ] ); //剩下那一部分也要算上
  if( f[ x ] < f[ root ] ) root = x;  //寻找最小的
}

然后我们用树的重心跑出子树每个的深度值,放在 depth 数组里。

参考代码:

void GetDep( int x , int fa , int sum ) {
  dep[ ++tot ] = sum;
  for( int i = head[ x ] ; i ; i = e[ i ].nxt ) {
    int y = e[ i ].v;
    if( vis[ y ] || y == fa ) continue;
    GetDep( y , x , sum + e[ i ].w );
  }
}

然后将 depth 数组排序,我们就知道每个的距离了。

如果枚举每一个节点,那时间复杂度会增加。

所以要排序,然后通过双指针求出个数。

例如这个数组:

0 1 2 2 3
k=4

现在 l=1,r=5

$a_2+a_5=4$,答案数加 $5-2=3$。 $a_3+a_5>4$,超过了,$r$ 减一。 $a_3+a_4=4$,答案数加 $4-3=1$。 然后就搞定了。 注意,如果两个不经过重心,如 $d-e$,路径 $d-c-a-c-e$ 还是 $\le k$ 的话,我们就要把这个答案减掉,等到子树重心为 $c$ 时才算入答案。 完整代码: ```.cpp #include <bits/stdc++.h> using namespace std; const int _ = 4e4 + 5; inline int Read() { int x = 0 , f = 1; char c = getchar(); for( ; c < '0' || c > '9' ; c = getchar() ) f ^= ( c == '-' ); for( ; c >= '0' && c <= '9' ; c = getchar() ) x = ( x << 3 ) + ( x << 1 ) + ( c ^ 48 ); return f ? x : -x; } struct Edge { int v , w , nxt ; } e[_*2]; int head[_] , ecnt; void Add( int u , int v , int w ) { e[ ++ecnt ] = Edge{ v , w , head[ u ] } ; head[ u ] = ecnt ; } int n , k , root , sum , tot , ans; int f[_] , siz[_] , dep[_] , vis[_]; void GetRoot( int x , int fa ) { f[ x ] = 0 , siz[ x ] = 1; for( int i = head[ x ] ; i ; i = e[ i ].nxt ) { int y = e[ i ].v; if( vis[ y ] || y == fa ) continue; GetRoot( y , x ); siz[ x ] += siz[ y ]; f[ x ] = max( f[ x ] , siz[ y ] ); } f[ x ] = max( f[ x ] , sum - siz[ x ] ); if( f[ x ] < f[ root ] ) root = x; } void GetDep( int x , int fa , int l ) { dep[ ++tot ] = l; for( int i = head[ x ] ; i ; i = e[ i ].nxt ) { int y = e[ i ].v; if( vis[ y ] || y == fa ) continue; GetDep( y , x , l + e[ i ].w ); } } int Calc( int x , int L , int ans = 0 ) { tot = 0; GetDep( x , 0 , L ); sort( dep + 1 , dep + 1 + tot ); int l = 1 , r = tot; while( l < r ) { //排序双指针法求答案 if( dep[ l ] + dep[ r ] <= k ) { ans += r - l; l++; } else r--; } return ans; } void Devide( int x ) { //分治 vis[ x ] = 1; ans += Calc( x , 0 ); for( int i = head[ x ] ; i ; i = e[ i ].nxt ) { int y = e[ i ].v; if( vis[ y ] ) continue; ans -= Calc( y , e[ i ].w ); //减去重复的,当前总共为e[i].w root = 0 , sum = siz[ y ]; GetRoot( y , 0 ); Devide( root ); //继续分治 } } int main() { n = Read(); sum = f[ 0 ] = n; for( int i = 1 ; i < n ; i++ ) { int u = Read() , v = Read() , w = Read(); Add( u , v , w ) , Add( v , u , w ); } k = Read(); GetRoot( 1 , 0 ); Devide( root ); printf( "%d\n" , ans ); return 0; } ``` 练习: [P3806](https://www.luogu.com.cn/problem/P3806)