Splay

· · 个人记录

平衡树的一种写法。

当前点的size值是左儿子的size+右儿子的size+该节点元素出现的次数cnt

inline void pushup ( int u ) {
    t[ u ].size = t[ t[ u ].ch[ 0 ] ].size + t[ t[ u ].ch[ 1 ] ].size + t[ u ].cnt ; 
}
inline void rotate ( int x ) {
    int y = t[ x ].fa , z = t[ y ].fa ;
    int k = ( t[ y ].ch[ 1 ] == x ) ;//左儿子还是右儿子
    t[ z ].ch[ t[ z ].ch[ 1 ] == y ] = x ;
    t[ x ].fa = z ;
    t[ y ].ch[ k ] = t[ x ].ch[ k ^ 1 ] ;
    t[ t[ x ].ch[ k ^ 1 ] ].fa = y ;
    t[ x ].ch[ k ^ 1 ] = y ;
    t[ y ].fa = x ;
    pushup ( y ) ; pushup ( x ) ;
}
inline void splay ( int x , int g ) {
    while ( t[ x ].fa != g ) {
        int y = t[ x ].fa , z = t[ y ].fa ;
        if ( z != g ) {
            ( t[ y ].ch[ 0 ] == x ) ^ ( t[ z ].ch[ 0 ] == y ) ? rotate ( x ) : rotate ( y ) ;
        }
        rotate ( x ) ;
    }
    if ( g == 0 ) root = x ;
}
inline void insert ( int x ) {
    int u = root , fa = 0 ;
    while ( u && t[ u ].val != x ) {
        fa = u ;
        u = t[ u ].ch[ x > t[ u ].val ] ;
    }
    if ( u ) t[ u ].cnt ++ ;
    else {
        u = ++ tot ;
        if ( fa ) t[ fa ].ch[ x > t[ fa ].val ] = u ;
        t[ u ].ch[ 0 ] = t[ u ].ch[ 1 ] = 0 ;
        t[ tot ].fa = fa ;
        t[ tot ].val = x ;
        t[ tot ].cnt = t[ tot ].size = 1 ;
    }
    splay ( u , 0 ) ;
}
inline void find ( int x ) {
    int u = root ;
    if ( !u ) return ;
    while ( t[ u ].ch[ x > t[ u ].val ] && x != t[ u ].val ) u = t[ u ].ch[ x > t[ u ].val ] ;
    splay ( u , 0 ) ;
}
inline int Next ( int x , int f ) {
    find ( x ) ;
    int u = root ; 
    if ( t[ u ].val > x && f ) return u ;
    if ( t[ u ].val < x && !f ) return u ;
    u = t[ u ].ch[ f ] ;
    while ( t[ u ].ch[ f ^ 1 ] ) u = t[ u ].ch[ f ^ 1 ] ;
    return u ;
}
inline void Delete ( int x ) {
    int last = Next ( x , 0 ) ;
    int next = Next ( x , 1 ) ;
    splay ( last , 0 ) ; splay ( next , last ) ;
    int del = t[ next ].ch[ 0 ] ;
    if ( t[ del ].cnt > 1 ) {
        t[ del ].cnt -- ;
        splay ( del , 0 ) ;
    }
    else t[ next ].ch[ 0 ] = 0 ;
}
inline int kth ( int x ) {
    int u = root ;
    if ( t[ u ].size < x ) return 0 ;
    while ( 1 ) {
        int y = t[ u ].ch[ 0 ] ;
        if ( x > t[ y ].size + t[ u ].cnt ) {
            x -= t[ y ].size + t[ u ].cnt ;
            u = t[ u ].ch[ 1 ] ;
        }
        else {
            if ( t[ y ].size >= x ) u = y ;
            else return t[ u ].val ;
        }
    }
}

参考: