线段树合并记

· · 算法·理论

线段树合并听起来十分高端,但其实非常好理解。

对于两个维护区间和的线段树 T_1,T_2 的合并,大致如下图:

其实就是将两棵线段树的相同位置的信息合并。假设我们仍要使用 T_1 存储合并后的线段树,即合并是将 T_2 合并到 T_1 上。

满线段树的合并显然是 \mathcal{O}(n) 的。但是实际上需要线段树合并的题一般是动态开点的,所以不会是满的。

在线段树不满的时候,考虑尽可能的剪枝,考虑从根开始 dfs,合并两树相同位置的信息,对于两树同个位置,分为四种情况:

  1. T_1,T_2 上均没有该节点,此时可以直接返回。
  2. T_1 上有该结点,而 T_2 没有。此时可以直接返回,因为 T_2 没有东西让我们继续合并。
  3. T_1 上没有该结点,而 T_2 上有。此时可以将 T_2 的对应子树直接拼在 T_1 上,然后直接返回。这里只需修改儿子指针即可。
  4. T_1,T_2 上均有该结点,考虑将两个结点信息合并,然后分别向左右儿子递归合并。

观察到,对于一个位置,只要不是在 T_1,T_2 上均有该结点,我们就不会递归下去。所以单次线段树合并的运算次数是两树的交集大小。

假如我们要合并结点数为 s_1 ,s_2 ,\dots ,s_mm 棵线段树。两棵线段树 i,j 的交集显然 \le \min\{s_i,s_j\}。故依次合并这些线段树的运算次数 \le \sum \limits _{i=2}^{m}\min\{s_i,\sum \limits _{j=1}^{i-1}s_j\} \le \sum _{i=1}^{m}\limits s_i。也就是线段树合并的复杂度不超过线段树结点数之和。

对线段树合并最常见到情况进行复杂度分析:

假如现在有 m 棵线段树分别进行过了 \mathcal{O}(1) 次操作,即每棵线段树上有 \mathcal{O}(\log n) 个结点,故所有线段树中结点数之和为 \mathcal{O}(m \log n)。所以复杂度也为 \mathcal{O}(m \log n)。假如有更加复杂的情况也可以使用类似的方法分析。

假如 T_1,T_2 中的信息需要复用,则需要在开出一个 T_3 来存合并后的信息。

例题 1

给定一棵 n 个结点的树,有 m 次操作,每次操作会往从 xy 的路径上放一袋类型为 z 的粮。操作结束后问你每个结点上最多的是那种粮。

考虑树上差分,在每个位置维护一个桶。每次在 x,y 处给 z 的数量 +1,在 \mathrm{LCA}(x,y)f_{\mathrm{LCA}(x,y)} 处给 z 的数量 -1

然后自下而上将儿子桶中所有的树加到其父亲上。直接做可以做到 \mathcal{O}(n^2)

考虑将每个位置的桶转成权值线段树,然后桶中值加到父亲的过程可以使用线段树合并。

由于每次操作只有 \mathcal{O}(1) 次线段树操作,所以复杂度为 \mathcal{O}(m \log n)

由于代码比较丑就先不放了,不知道代码怎么写的可以参考后面放的代码。

例题 2

给定一棵 n 个叶子的二叉树,叶子上有值,且所有叶子的值构成 1 \sim n 的排列。你可以交换任意结点的儿子,最小化按中序遍历顺序写下叶子上值的排列的逆序对数。

自下而上考虑是否交换儿子。

对于一个结点 x 而言,构成逆序对的两点有三种情况:

  1. 两个值均在左子树。
  2. 两个值均在右子树。
  3. 两个值分别在左子树和右子树。

对于 1. 和 2. 两种情况的逆序对已然无法改变,我们能最小化的只有 3. 构成的逆序对。

考虑维护两个桶 LR,分别表示左子树和右子树中各个数的出现次数。

不交换的逆序对数为 e_1=\sum \limits _{i=1}^{n}\sum \limits _{j=1}^{i-1}L_iR_j。交换则是 e_2=\sum \limits _{i=1}^{n}\sum \limits _{j=i+1}^{n}L_iR_j

使用线段树合并维护这些桶,合并的途中我们可以顺便计算出 e_1,e_2

可以发现这里是否交换和以后新的逆序对是否构成都无关,所以我们可以让答案直接加上 \min\{e_1,e_2\}

时间复杂度为 \mathcal{O}(n \log n)

放这题的用意是说,看见题目中出现叶子、二叉树的情况都有很大的概率是线段树合并。以及本题代码很短。

如果题目中的树不是二叉树就说明要么儿子之间影响是无用或简单的,要么说明本题不是线段树合并。

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=200005;
int n,a1,a2,ans,cnt;
struct node{signed ls,rs,ct;}t[N*20];
inline int merge(int x,int y){
    if(!x||!y)return x|y;
    a1+=1ll*t[t[x].ls].ct*t[t[y].rs].ct,a2+=1ll*t[t[x].rs].ct*t[t[y].ls].ct;
    t[x].ls=merge(t[x].ls,t[y].ls),t[x].rs=merge(t[x].rs,t[y].rs);
    return t[x].ct=t[t[x].ls].ct+t[t[x].rs].ct,x;
}
inline void gt(signed &id,int l,int r,int x){
    if(!id)id=++cnt;++t[id].ct;if(l==r)return;int mid=l+r>>1;
    x<=mid?gt(t[id].ls,l,mid,x):gt(t[id].rs,mid+1,r,x);
}
inline int sol(){
    signed pos=0,t;cin>>t;
    if(!t){
        int ls=sol(),rs=sol();
        a1=a2=0,pos=merge(ls,rs),
        ans+=min(a1,a2);
    }else gt(pos,1,n,t);
    return pos;
}
signed main(){
    cin>>n,sol(),cout<<ans<<'\n';
    return 0;
}

例题 3

有一颗 n 个结点的二叉树。叶子上有互不相同的权值已经确定,非叶子结点 x 的权值有 p_x 的概率是所有儿子权值的最大值,有 (1-p_x) 的概率是所有儿子权值的最小值。求根节点每种权值的出现概率。

考虑树上期望 dp。设 f_{i,j} 为结点 i 的权值为 j 的概率。

转移讨论两种情况,一种是取 \max 的情况,此时就是一边 =j,一边 \le j。另一种是取 \min 的情况,此时是一边 =j,一边 \ge j。然后减去重复出现的两边同时 =j,不过因为叶子权值不同所以不会算重。

具体而言整理后可以得到转移方程为:

f_{i,j}=f_{l,j}((1-p_i)\sum _{k=1}^{j-1}f_{r,k}+p_i\sum _{k=j+1}^{n}f_{r,k})+f_{r,j}((1-p_i)\sum _{k=1}^{j-1}f_{l,k}+p_i\sum _{k=j+1}^{n}f_{l,k})

直接做复杂度为 \mathcal{O}(n^2)。考虑使用线段树合并优化转移。

在每个结点处维护一个线段树存下 f_i 的值。合并的时候同时我们两边线段树到当前位置的前缀和以及后缀和。然后就是对着式子模拟即可。

时间复杂度为 \mathcal{O}(n \log n)

#include<bits/stdc++.h>
#define int long long
#define rd read()
#define gc pa == pb && (pb = (pa = buf) + fread(buf, 1, 100000, stdin), pa == pb) ? EOF : *pa++
using namespace std;
static char buf[100000], * pa(buf), * pb(buf);
inline int read(){
    unsigned int x=0,s=gc;while(!isdigit(s))s=gc;
    while(isdigit(s))x=(x<<1)+(x<<3)+(s^48),s=gc;
    return x;
}
const int mod=998244353;
inline int fpow(int x,int y=mod-2){
    int res=1;
    for(;y;y>>=1,(x*=x)%=mod)
        if(y&1)(res*=x)%=mod;
    return res;
}
const int N=300005,inv=fpow(10000);
struct node{
    int ls,rs,sm,tg;
}t[N<<5];
int n,m,ans,cnt,tot,rt[N],p[N],d[N];vector<int> v[N];
inline int New(){t[++cnt]={0,0,0,1};return cnt;}
inline void push(int id,int w){(t[id].tg*=w)%=mod,(t[id].sm*=w)%=mod;}
inline void pushdown(int id){if(t[id].tg!=1)push(t[id].ls,t[id].tg),push(t[id].rs,t[id].tg),t[id].tg=1;}
inline void U(int &id,int l,int r,int x){
    ++t[id=New()].sm;if(l==r)return;int mid=l+r>>1;
    x<=mid?U(t[id].ls,l,mid,x):U(t[id].rs,mid+1,r,x);
}
inline int merge(int x,int y,int l,int r,int X,int Y,int P){
    if(!x)return push(y,X),y;if(!y)return push(x,Y),x;
    pushdown(x),pushdown(y);int mid=l+r>>1,yls=t[t[y].ls].sm,xls=t[t[x].ls].sm;
    t[x].ls=merge(t[x].ls,t[y].ls,l,mid,(X+(1-P+mod)*t[t[x].rs].sm)%mod,(Y+(1-P+mod)*t[t[y].rs].sm)%mod,P);
    t[x].rs=merge(t[x].rs,t[y].rs,mid+1,r,(X+P*xls)%mod,(Y+P*yls)%mod,P);
    return t[x].sm=(t[t[x].ls].sm+t[t[x].rs].sm)%mod,x;
}
inline void dfs(int x){
    if(!v[x].size())return U(rt[x],1,tot,p[x]);dfs(v[x][0]);
    if(v[x].size()>1)dfs(v[x][1]),rt[x]=merge(rt[v[x][0]],rt[v[x][1]],1,tot,0,0,p[x]);
    else rt[x]=rt[v[x][0]];
}
inline void get(int id,int l,int r){
    if(!id)return;if(l==r)return (ans+=l*d[l]%mod*t[id].sm%mod*t[id].sm)%=mod,void();
    int mid=l+r>>1;pushdown(id);get(t[id].ls,l,mid),get(t[id].rs,mid+1,r);
}
signed main(){
    n=rd;
    for(int i=1;i<=n;++i)v[rd].push_back(i);
    for(int i=1;i<=n;++i)p[i]=v[i].size()?rd*inv%mod:d[++tot]=rd;
    stable_sort(d+1,d+tot+1);
    for(int i=1;i<=n;++i)if(!v[i].size())p[i]=lower_bound(d+1,d+tot+1,p[i])-d;
    dfs(1),get(rt[1],1,tot);cout<<ans<<'\n';return 0;
}

例题 4

给定一棵 n 个结点的树,你可以对每条边染色为黑或白,有 m 条限制 (u,v) 表示链 (u,v) 上有至少一个黑点,其中 vu 的祖先。求染色方案数。

与题面定义不同,下面均有 uv 的祖先。

考虑到对于两个限制 (u,v_1),(u,v_2),其中 u 为下方的点,且 \mathrm{dep}_{v_1}> \mathrm{dep}_{v_2}。有用的限制显然只有 (u,v_1),因为当 (u,v_1) 满足时 (u,v_2) 也会被满足。

不妨设 f_{x,i}u 在以 x 为根的子树中,且 v 不在子树内,即 x 在路径 (u,v) 上,未满足限制中 v 的最大深度为 i 的方案数。其中 f_{x,0} 表示合法方案数。

此时我们发现,当我们在 x 连父亲的边放 1f_x 中的所有限制都被满足。

对于一个点 x 的若干儿子 y_1 ,y_2\ldots y_k,考虑如何将信息合并。

当前所有形如 (x,y_i) 的边都不确定,而在确定这些边的过程当中儿子 y_i 之间独立,考虑依次合并儿子的信息到 x 的信息上。

对于其中一条边 (x,y\in \{y_1,y_2\ldots y_k\}) 而言,分别讨论将其取 01 的情况,然后将两种情况的贡献相加。

当这条边取 1 时,清空了子树 y 中的限制,所以对 f_{x,i} 的贡献为 \sum \limits _{j=0}^{\mathrm{dep}_x}f_{x,i}f_{y,j}

当这条边取 0 时,子树 y 中的限制不变,讨论谁剩下的深度更大即可,这种情况对答案的贡献为 \sum \limits _{j=0}^{i-1}f_{x,i}f_{y,j}+\sum \limits _{j=0}^{i}f_{x,j}f_{y,i}

整理后得到转移方程为:

f_{x,i}\gets f_{x,i}(\sum \limits _{j=0}^{\mathrm{dep}_x}f_{y,j}+\sum \limits _{j=0}^{i-1}f_{y,j})+f_{y,i}\sum \limits _{j=0}^{i}f_{x,j}

考虑使用线段树合并优化转移,只需在合并的同时维护前缀和即可。

时空复杂度均为 \mathcal{O}(n \log n)

#include<bits/stdc++.h>
#define int long long
#define rd read()
#define gc pa == pb && (pb = (pa = buf) + fread(buf, 1, 100000, stdin), pa == pb) ? EOF : *pa++
using namespace std;
static char buf[100000], * pa(buf), * pb(buf);
inline int read(){
    unsigned int x=0,s=gc;while(!isdigit(s))s=gc;
    while(isdigit(s))x=(x<<1)+(x<<3)+(s^48),s=gc;
    return x;
}
const int N=500005,mod=998244353;
int n,m,cnt,rt[N],dep[N];
vector<int> v[N],g[N];
struct node{int ls,rs,S,T;}t[N<<7];
inline int New(){return ++cnt;}
inline void chk(int &x,int y){x=x<y?y:x;}
inline void push(int id,int w){(t[id].T*=w)%=mod,(t[id].S*=w)%=mod;}
inline void pushdown(int id){if(t[id].T!=1)push(t[id].ls,t[id].T),push(t[id].rs,t[id].T),t[id].T=1;}
inline int merge(int x,int y,int l,int r,int s1,int s2){
    if(!x&&!y)return 0;if(!x)return push(y,s1),y;if(!y)return push(x,s2),x;
    if(l==r)return t[x].S=(t[y].S*(s1+t[x].S)%mod+t[x].S*s2)%mod,x;
    int mid=l+r>>1;pushdown(x),pushdown(y);
    t[x].rs=merge(t[x].rs,t[y].rs,mid+1,r,s1+t[t[x].ls].S,s2+t[t[y].ls].S);
    t[x].ls=merge(t[x].ls,t[y].ls,l,mid,s1,s2);
    t[x].S=(t[t[x].ls].S+t[t[x].rs].S)%mod;return x;
}
inline void U(int &id,int l,int r,int x){
    id=++cnt,t[id].T=t[id].S=1;if(l==r)return;int mid=l+r>>1;
    x<=mid?U(t[id].ls,l,mid,x):U(t[id].rs,mid+1,r,x);
}
inline int Q(int id,int l,int r,int x){
    if(!id)return 0;if(r<=x)return t[id].S;pushdown(id);int mid=l+r>>1;
    return (Q(t[id].ls,l,mid,x)+(x>mid?Q(t[id].rs,mid+1,r,x):0))%mod;
}
inline void dfs(int x,int f){
    int mx=0;dep[x]=dep[f]+1;
    for(int i:g[x])chk(mx,dep[i]);U(rt[x],0,n,mx);
    for(int i:v[x])if(i!=f)
        dfs(i,x),rt[x]=merge(rt[x],rt[i],0,n,0,Q(rt[i],0,n,dep[x]));
}
signed main(){
    n=rd;
    for(int i=1,x,y;i<n;++i)
        x=rd,y=rd,
        v[x].push_back(y),v[y].push_back(x);
    m=rd;
    for(int i=1,x,y;i<=m;++i)
        x=rd,y=rd,
        g[y].push_back(x);
    dfs(1,0),cout<<Q(rt[1],0,n,0)<<'\n';return 0;
}

例题 5

给定一颗 n 个结点的树,树有点权。问树上有多少个不同的连通块满足连通块中的不同点权数 \le 2

联通块计数,对于所有的联通块考虑在其最浅的结点处统计,相当于我们已经确定了一个结点的位置,即我们在连通块中已经有了一个颜色。

记录 f_{i,j} 表示连通块只存在颜色 a_ij 的方案数。特别的,当 a_i=j 时,连通块只有颜色 a_i

依次考虑 i 的所有儿子 y_1 ,y_2\ldots y_k,将它们的信息合并到 i 上。考虑至少选择了两个结点 iy_j 的方案。

a_i \ne a_y,加入的其他结点不可以有除 a_i,a_y 外的其他颜色。所以此时可以得到 f_{i,a_y}=(f_{i,a_i}+f_{i,a_y})(f_{y,a_i}+f_{y,a_y})+f_{i,a_i}。对于其他的 f_{i,j} 满足 j \ne a_y 则是不变的。

a_i = a_y,我们还可以在新增一个颜色,考虑去枚举这个颜色 x。对于 x \ne a_if_{i,x}=f_{i,x}(f_{y,a_y}+f_{y,x})+f_{i,a_i}f_{y,x}+f_{i,x}。而当 x = a_i 时则是 f_{i,x}=f_{i,x}f_{y,x}+f_{i,x}

使用线段树合并维护即可,时间复杂度为 \mathcal{O}(n \log n)

#include<bits/stdc++.h>
#define int long long
#define rd read()
#define gc pa == pb && (pb = (pa = buf) + fread(buf, 1, 100000, stdin), pa == pb) ? EOF : *pa++
using namespace std;
static char buf[100000], * pa(buf), * pb(buf);
inline int read(){
    register int x=0,ss=1,s=gc;
    while(!isdigit(s)&&s!='-')s=gc;if(s=='-')ss=-1,s=gc;
    while(isdigit(s))x=(x<<1)+(x<<3)+(s^48),s=gc;
    return ss*x;
}
const int N=500005,M=N*40,mod=998244353;
int n,cnt,ans,a[N],rt[N],ls[M],rs[M],tg[M],s[M];vector<int> v[N];
inline void push(int id,int v){
    if(id)(s[id]*=v)%=mod,(tg[id]*=v)%=mod;
}
inline void pushdown(int id){
    if(tg[id]!=1)push(ls[id],tg[id]),push(rs[id],tg[id]),tg[id]=1;
}
inline void U(int &id,int l,int r,int x,int y){
    if(!id)id=++cnt;(s[id]+=y)%=mod;if(l==r)return;int mid=l+r>>1;
    pushdown(id),x<=mid?U(ls[id],l,mid,x,y):U(rs[id],mid+1,r,x,y);
}
inline int Q(int id,int l,int r,int x){
    if(!id)return 0;if(l==r)return s[id];int mid=l+r>>1;
    pushdown(id);return x<=mid?Q(ls[id],l,mid,x):Q(rs[id],mid+1,r,x);
}
inline void P(int id,int l,int r,int x,int y,int k){
    if(!id)return;if(x<=l&&y>=r)return push(id,k);int mid=l+r>>1;pushdown(id);
    if(x<=mid)P(ls[id],l,mid,x,y,k);if(y>mid)P(rs[id],mid+1,r,x,y,k);s[id]=s[ls[id]]+s[rs[id]];
}
inline int merge(int x,int y,int l,int r,int h1,int h2,int t){
    if(!y)return push(x,h1+1),x;if(!x)return push(y,h2),y;
    if(l==r)return l==t?(s[x]*=(s[y]+1))%=mod:
        (s[x]+=s[x]*(h1+s[y])+h2*s[y])%=mod,x;
    int mid=l+r>>1;pushdown(x),pushdown(y);
    ls[x]=merge(ls[x],ls[y],l,mid,h1,h2,t);
    rs[x]=merge(rs[x],rs[y],mid+1,r,h1,h2,t);
    return s[x]=s[ls[x]]+s[rs[x]],x;
}
inline void dfs(int x,int f){
    U(rt[x],1,n,a[x],1);
    for(int i:v[x])if(i!=f){
        dfs(i,x);
        if(a[x]!=a[i]){
            int A=Q(rt[x],1,n,a[x]),B=Q(rt[x],1,n,a[i]),
                C=Q(rt[i],1,n,a[x]),D=Q(rt[i],1,n,a[i]);
            U(rt[x],1,n,a[i],(A+B)*(C+D));
        }else{
            int A=Q(rt[i],1,n,a[i]),B=Q(rt[x],1,n,a[x]);
            rt[x]=merge(rt[x],rt[i],1,n,A,B,a[x]);
        }
    }
    (ans+=s[rt[x]])%=mod;
}
signed main(){
    for(int i=1;i<M;++i)tg[i]=1;n=rd;
    for(int i=1;i<=n;++i)a[i]=rd;
    for(int i=1,x,y;i<n;++i)
        x=rd,y=rd,
        v[x].push_back(y),
        v[y].push_back(x);
    dfs(1,0),cout<<ans<<'\n';return 0;
}

例题 6

给定一棵 n 个结点的有向树,点有点权。有一个集合,一开始为空,m 次操作:在集合中加入或删除一个点,或给出一个点 x,询问集合中满足可以走到 x 的点的点权之和的历史最大值。

本题难点在于如何刻画“可以走到 x 的点”这样的条件。

单点插入或删除即单点点权的加减,考虑将单点加转换为对所有可以走到的点加,查询变为单点查。

然后考虑对于每个点求出其在每个时刻下的点权,查询答案即为前缀最大值。

在每个位置对时间轴维护一个线段树,初始时存下该结点在各时刻点权的变化量。然后按拓扑序将线段树合并,将变化量传递下去,被传递的点又会继续传下去,直到碰见某个无法到达的点。

然后就做完了,时间复杂度为 \mathcal{O}(m \log n)

#include<bits/stdc++.h>
#define int long long
#define rd read()
#define gc pa == pb && (pb = (pa = buf) + fread(buf, 1, 100000, stdin), pa == pb) ? EOF : *pa++
using namespace std;
static char buf[100000], * pa(buf), * pb(buf);
inline int read(){
    unsigned int x=0,s=gc;while(!isdigit(s))s=gc;
    while(isdigit(s))x=(x<<1)+(x<<3)+(s^48),s=gc;
    return x;
}
const int N=200005;
int n,m,cnt,a[N],rt[N],lst[N],op[N],X[N],in[N];vector<int> v[N];
struct node{int ls,rs,mx,tg;}t[N<<6];
inline void chkmax(int &x,int y){x=x<y?y:x;}
inline void pushup(int id){t[id].mx=max(t[t[id].ls].mx,t[t[id].rs].mx)+t[id].tg;}
inline void push(int id,int w){t[id].mx+=w,t[id].tg+=w;}
inline void U(int &id,int l,int r,int x,int y,int w){
    if(!id)id=++cnt;if(x<=l&&y>=r)return push(id,w);int mid=l+r>>1;
    if(x<=mid)U(t[id].ls,l,mid,x,y,w);if(y>mid)U(t[id].rs,mid+1,r,x,y,w);pushup(id);
}
inline int Q(int id,int l,int r,int y){
    if(!id)return 0;if(y>=r)return t[id].mx;int mid=l+r>>1,res=0;
    res=Q(t[id].ls,l,mid,y);if(y>mid)chkmax(res,Q(t[id].rs,mid+1,r,y));
    return res+t[id].tg;
}
inline int merge(int x,int y){
    if(!x||!y)return x|y;int o=++cnt;t[o]=t[x],push(o,t[y].tg);
    t[o].ls=merge(t[o].ls,t[y].ls),t[o].rs=merge(t[o].rs,t[y].rs);
    return pushup(o),o;
}
signed main(){
    n=rd,m=rd;for(int i=1;i<=n;++i)a[i]=rd;
    for(int i=1,x;i<n;++i)v[rd].push_back(x=rd),++in[x];
    for(int i=1;i<=m;++i){
        op[i]=rd,X[i]=rd;
        if(op[i]==1)lst[X[i]]=i;
        if(op[i]==2)U(rt[X[i]],1,m,lst[X[i]],i-1,a[X[i]]),lst[X[i]]=0;
    }
    for(int i=1;i<=n;++i)if(lst[i])U(rt[i],1,m,lst[i],m,a[i]);
    queue<int> q;for(int i=1;i<=n;++i)if(!in[i])q.push(i);
    while(q.size()){
        int u=q.front();q.pop();
        for(int i:v[u]){
            --in[i],rt[i]=merge(rt[i],rt[u]);
            if(!in[i])q.push(i);
        }
    }
    for(int i=1;i<=m;++i)if(op[i]==3)cout<<Q(rt[X[i]],1,m,i)<<'\n';return 0;
}