左偏树学习笔记

· · 个人记录

Part 1 它是什么?

准确来说,它是一个

我们以小根堆为例,构建一个左偏树。

它有两个要求:

  1. 对于每棵子树来说,根节点的权值小于等于子节点的权值。
  2. 定义距离为一个节点到它最近的空节点的距离,那么每个子树中,左儿子的距离均大于等于右儿子的距离。 (黑色您可以理解成是节点权值/编号,红色为当前节点的距离)
一个证明:

我们可以证明一个东西:根节点距离 d \le \log(n)

以下是证明过程:

定义 f(k) 是距离为 k 的子树至少包含几个点。

我们用数学归纳法证明 f(k) \ge 2^k-1.

很显然,当 k=1 时,f(1)=1 \ge 2^1-1.

我们设 n=k-1 时上面成立,证明 n=k 时也成立。

根据图片,我们知道:

所以 $n=k$ 时也成立,故 $f(k) \ge 2^k-1$. 现在,我们需要找到最大的 $k$,使得 $2^k-1 \le n

解方程,k \le \log_2{(n+1)}

Part 2.它可以做什么?

左偏树支持以下四个操作。

  1. 插入一个树,复杂度O(\log n)
  2. 求最小值(以上面为例),复杂度O(1)
  3. 删除最小值(以上面为例),复杂度O(\log n)
  4. 合并两棵树,复杂度O(\log n)

核心操作——合并

合并两颗树 ab 的过程如下:

  1. 确认根节点,这里不妨设 a 根节点的值小于 b 根节点的值。
  2. 递归将 a 的右子树与 b 合并。
  3. 判断右子树距离是否小于左子树,否则将两者交换。
一个证明

我们来证明它的时间复杂度。

我们设 a 根节点距离 dist[a]b 根节点距离 dist[b]。 每一次递归,两者都会有一个减一,因此一共执行了 dist[a]+dist[b] 次。 因为 dist[a]dist[b] 均为 \log(n) 级别,因此时间复杂度 O(\log n)

有了合并操作,(1)(2)操作就很好解释了。(1)实际上是一个左偏树与另一个只有一个点的左偏树合并,(2)根据定义就是根节点的值。

(3)操作呢?实际上也很简单,将根节点删除后,将剩余的两棵子树合并就可以了。

Part 3.一些题目

P3377

板子题。 左偏树常常与并查集一起用,并且会进行换根操作。 具体细节见代码。

#include<bits/stdc++.h>
using namespace std;
const int N=200005;
typedef pair<int,int> PII;
int n,m,idx,par[N];
int l[N],r[N],dist[N],st[N];
PII w[N];
int find(int x)
{
    if(par[x]==x)return x;
    else return par[x]=find(par[x]);
}
int merge(int x,int y)
{
    if(!x||!y)return x+y;
    if(w[y]<w[x])swap(x,y);
    r[x]=merge(r[x],y);
    if(dist[r[x]]>dist[l[x]])swap(l[x],r[x]);
    dist[x]=dist[r[x]]+1;
    return x;
}
int main()
{
    dist[0]=-1;
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&w[i].first);
        par[i]=i;
        w[i].second=i;
    }
    while(m--)
    {
        int op,x,y;
        scanf("%d%d",&op,&x);
        if(op==1)
        {
            scanf("%d",&y);
            if(st[x]||st[y])continue;
            x=find(x);
            y=find(y);
            if(x==y)continue;
            par[x]=par[y]=merge(x,y);
        }
        else if(op==2)
        {
            if(st[x])
            {
                puts("-1");
                continue;
            }
            x=find(x);
            printf("%d\n",w[x].first);
            st[x]=1;
            par[l[x]]=par[r[x]]=par[x]=merge(l[x],r[x]);
            l[x]=r[x]=dist[x]=0;
        }
    }
    return 0;
}

Acwing 2714

与上一题类似。

#include<bits/stdc++.h>
using namespace std;
const int N=200005;
int n,idx,par[N];
int l[N],r[N],w[N],dist[N];
int cmp(int x,int y)
{
    if(w[x]!=w[y])return w[x]<w[y];
    return x<y;
}
int find(int x)
{
    if(par[x]==x)return x;
    else return par[x]=find(par[x]);
}
int merge(int x,int y)
{
    if(!x||!y)return x+y;
    if(cmp(y,x))swap(x,y);
    r[x]=merge(r[x],y);
    if(dist[r[x]]>dist[l[x]])swap(l[x],r[x]);
    dist[x]=dist[r[x]]+1;
    return x;
}
int main()
{
    w[0]=2e9;
    scanf("%d",&n);
    while(n--)
    {
        int op,x,y;
        scanf("%d%d",&op,&x);
        if(op==1)
        {
            w[++idx]=x;
            dist[idx]=1;
            par[idx]=idx;
        }
        else if(op==2)
        {
            scanf("%d",&y);
            x=find(x);
            y=find(y);
            if(x==y)continue;
            if(cmp(y,x))swap(x,y);
            par[y]=x;
            merge(x,y);   
        }
        else if(op==3)printf("%d\n",w[find(x)]);
        else
        {
            x=find(x);
            if(cmp(r[x],l[x]))swap(l[x],r[x]);
            par[x]=l[x];
            par[l[x]]=l[x];
            merge(l[x],r[x]);
        }
    }
    return 0;
}