DS 笔记:珂朵莉树学习笔记

· · 个人记录

珂朵莉最可爱啦。

引子

CF896C Willem, Chtholly and Seniorious

要求实现一种数据结构,能维护下列四个操作:

  • 区间赋值为给定数 x
  • 区间修改(加上给定数 x)。
  • 求区间第 k 小数。
  • 区间 k 次幂取模。

具体的实现要求可以看原题。

乍一看这题题目要求较多,简单就此题要求而言,树状数组和线段树、平衡树看起来难以胜任,似乎需要一种很复杂的 Data Structure 来维护,但是赛后给出的题解里,代码实现的数据结构却异常简单——更具体地,甚至有些暴力,这就是珂朵莉树的来源,英文名字 ODT(Old Driver Tree)。

能使用 ODT 的题目大概存在有区间推平(赋值)操作的题目,但是这一切有一个前提:数据随机。所谓推平,便是将一个区间内的数据赋为统一值。而数据随机是因为珂朵莉树本质上是一种暴力修改,是瞄准了数据随机可能出现大段重复段的特性进行维护。虽说 ODT 看起来并没有树形结构的样子,但是其实现用到了 std::set 因此被称为 Tree。

ODT 核心思想

ODT 的主要思想是随机数据很可能使得一大段连续数字相同,因此我们可以把这一整段当成同一个数来处理。因此我们把一个序列以三元组 (l,r,v) 的形式存储,即一段 [l,r] 序列的值均为 v,譬如下一段序列可以如图表示:

因此我们需要定义一个结构体来存储,一个注意点是 v 需要定义为可修改类型,否则在之后的 std::set 中修改会被认为是语法错误的行为。同样,为了在 set 中能够自动以 \log 复杂度排序,我们需要重载运算符。

struct node{
    int l,r;
    mutable int v;
    node(int l,int r=-1,int v=0):l(l),r(r),v(v){}
    bool operator<(const node &x)const{return l<x.l;}
};

在结构体中,我们还实现了 node(l,r,v) 形式赋值的操作。

分裂区间

我们接下来实现 split 操作,这是 ODT 的核心操作,其思想是把一个连续区间断开处理。由于我们接受的操作并不一定是刚好覆盖整个区间的,因此我们需要断开区间,对于在操作范围内的数单独处理:

it split(int pos){
    it x=tr.lower_bound(node(pos,0,0));
    if(x!=tr.end()&&x->l==pos)return x;
    --x;int l=x->l,r=x->r,v=x->v;
    tr.erase(x);
    tr.insert(node(l,pos-1,v));
    return tr.insert(node(pos,r,v)).first;
}

其中 it 通过宏定义定义为 set<node>::iterator 也就是 set 类型的迭代器,类似于一个指针。我们做的操作便是把 [l,r] 区间断开成 [l,pos)[pos,r] 两个区间。我们首先找到用 lower_bound 找到左端点大于等于 pos 的第一个节点。如果已经存在以 pos 为左端点的区间,那我们可以直接返回答案。反之就往前找一个节点(这就是我们要断开的区间),删掉它 [l,r] 并插入 [l,pos)[pos,r]。最终我们只需要返回以 pos 开头的区间对应迭代器即可。

对于上文提及的区间,如果要 split(4) 怎么办?

首先 lower_bound 找到 (6,7,3),发现并不是结尾,于是向前找到 (3,5,2)。这时候我们开始处理——删除这个节点,插入 (3,3,2)(4,5,2) 代表从 3\sim 34\sim 5 两个区间(他们合并起来就是原区间),这样就完成了一次区间的分裂。

区间操作

你可能要问,这有什么用?和原来的区间不是等价的吗?其实我们通过 split 操作可以把需要处理的区间单独拿出来处理而不影响其他范围的数据。先看一看区间推平(赋值):

void assign(int l,int r,int v){
    it end=split(r+1),begin=split(l);
    tr.erase(begin,end);
    tr.insert(node(l,r,v));
}

其中 split 的顺序不能颠倒,否则可能会 RE。这一步的精髓在于我们的推平复杂度很低——直接将这一整段的区间删除,然后用一个简单的节点来代替——恰好,整段区间推平之后的结果可以用一个节点来表示,这就是 ODT 的牛逼之处。由于数据随机,有大约 25\% 的操作是 assign,所以这使得整棵树存储的结点数下降较快,有效保证了随机意义下的优秀复杂度。

而剩下的操作则是赤裸裸的暴力,譬如区间加,便是一个个加过去(并没有什么优化):

void add(int l,int r,int v){
    it lt=split(l),rt=split(r+1);
    for(;lt!=rt;lt++)lt->v+=v;
}

关于区间第 k 小,我们可以把区间内所有节点全部存下来然后进行排序,之后一个一个取出来就好了:

int Rank(int l,int r,int k){
    vector<pii >v;v.clear();
    it lt=split(l),rt=split(r+1);
    for(;lt!=rt;lt++)v.push_back(pii(lt->v,lt->r-lt->l+1));
    sort(v.begin(),v.end());
    for(vector<pii >::iterator x=v.begin();x!=v.end();++x){
        k-=x->second;
        if(k<=0)return x->first;
    }return -1ll;
}

快速幂的实现只需要将每一个区间单点求幂的值乘上区间长度即可,唯一需要注意的便是底数要记得对 y 先进行一次取模,代码实现如下:

int power(int l,int r,int x,int y){
    it lt=split(l),rt=split(r+1);
    int ans=0;
    for(;lt!=rt;lt++)ans=(ans+(lt->r-lt->l+1)*ksm(lt->v,x,y))%y;
    return ans;
}

代码实现

这里给出该题的全文代码实现。

#include<bits/stdc++.h>
#define it set<node>::iterator
#define int long long 
#define pii pair<int,int>
#define m7 1000000007
#define N 1000005
using namespace std;
struct node{
    int l,r;
    mutable int v;
    node(int l,int r,int v):l(l),r(r),v(v){}
    bool operator<(const node &x)const{return l<x.l;}
};set<node>tr;
int ksm(int b,int p,int mod){
    int ans=1;b%=mod;
    while(p){
        if(p&1)ans=ans*b%mod;
        b=b*b%mod;p>>=1;
    }return ans;
}it split(int pos){
    it x=tr.lower_bound(node(pos,0,0));
    if(x!=tr.end()&&x->l==pos)return x;
    --x;int l=x->l,r=x->r,v=x->v;
    tr.erase(x);
    tr.insert(node(l,pos-1,v));
    return tr.insert(node(pos,r,v)).first;
}void assign(int l,int r,int v){
    it end=split(r+1),begin=split(l);
    tr.erase(begin,end);
    tr.insert(node(l,r,v));
}void add(int l,int r,int v){
    it lt=split(l),rt=split(r+1);
    for(;lt!=rt;lt++)lt->v+=v;
}int Rank(int l,int r,int k){
    vector<pii >v;v.clear();
    it lt=split(l),rt=split(r+1);
    for(;lt!=rt;lt++)v.push_back(pii(lt->v,lt->r-lt->l+1));
    sort(v.begin(),v.end());
    for(vector<pii >::iterator x=v.begin();x!=v.end();++x){
        k-=x->second;
        if(k<=0)return x->first;
    }return -1ll;
}int power(int l,int r,int x,int y){
    it lt=split(l),rt=split(r+1);
    int ans=0;
    for(;lt!=rt;lt++)ans=(ans+(lt->r-lt->l+1)*ksm(lt->v,x,y))%y;
    return ans;
}int n,m,seed,vmax,a[N];
int rnd(){
    int ret=seed;
    seed=(seed*7+13)%1000000007;
    return ret;
}signed main(){
    scanf("%lld%lld%lld%lld",&n,&m,&seed,&vmax);
    for(int i=1;i<=n;i++){
        a[i]=(rnd()%vmax)+1;
        tr.insert(node(i,i,a[i]));
    }tr.insert(node(n+1,n+1,0));
    for(int i=1;i<=m;i++){
        int op=(rnd()%4)+1,l=(rnd()%n)+1,r=(rnd()%n)+1,x,y;
        if(l>r)swap(l,r);
        if(op==3)x=(rnd()%(r-l+1))+1;
        else x=rnd()%vmax+1;
        //以上代码为题目给定的随机数据方法 
        if(op==1)add(l,r,x);
        if(op==2)assign(l,r,x);
        if(op==3)printf("%lld\n",Rank(l,r,x));
        if(op==4){
            y=(rnd()%vmax)+1;
            printf("%lld\n",power(l,r,x,y));
        }
    }return 0;
}

相关例题

CF915E Physical Education Lessons

给定一个初始全 1 的长度为 n 的区间,要求实现两种操作:

  • 将其中一段全改为 1
  • 将其中一段全改为 0

这是一道非常简单可以用 ODT 实现的题目,我们抛弃先前实现的各种函数——因为这一道题目只需要区间推平就行,所以我们就使劲 assign 即可,代码实现很简单不到 1K:

#include<bits/stdc++.h>
#define it set<node>::iterator
#define int long long 
#define pii pair<int,int>
using namespace std;
struct node{
    int l,r;
    mutable int v;
    node(int l,int r,int v):l(l),r(r),v(v){}
    bool operator<(const node &x)const{return l<x.l;}
};set<node>tr;
int n,m,sum;
it split(int pos){
    it x=tr.lower_bound(node(pos,0,0));
    if(x!=tr.end()&&x->l==pos)return x;
    --x;int l=x->l,r=x->r,v=x->v;
    tr.erase(x);
    tr.insert(node(l,pos-1,v));
    return tr.insert(node(pos,r,v)).first;
}void assign(int l,int r,int v){
    it rt=split(r+1),lt=split(l),t=lt;
    for(;t!=rt;++t)sum-=t->v*(t->r-t->l+1);
    tr.erase(lt,rt);tr.insert(node(l,r,v));
    sum+=v*(r-l+1);
}signed main(){
    scanf("%lld%lld",&n,&m);
    tr.insert(node(1,n,1));sum=n;
    for(int i=1;i<=m;i++){
        int l,r,op;
        scanf("%lld%lld%lld",&l,&r,&op);
        if(op==1)assign(l,r,0);
        else assign(l,r,1);
        printf("%lld\n",sum);
    }return 0;
}

习题练手

P4344 [SHOI2015]脑洞治疗仪

给出三种操作:

  • 将一个区间赋值为 0
  • 将一个区间的 1 顺次填补到另一个区间内。
  • 统计最长的全 0 区间。

很显然,这道题在没有加强数据前可以用 ODT 实现。第一个操作只需要将区间推平为 0 即可。第二个操作的实现稍微有些复杂,实际上就是暴力统计前一个区间的 1 的个数,然后推平为 0,之后再顺次填入另一个区间内。对于第三个操作,我们只需要将每一个子段进行累加之后统计答案即可。这里给出实现:

#include<bits/stdc++.h>
#define it set<node>::iterator
#define int long long 
#define pii pair<int,int>
#define N 1000005
using namespace std;
struct node{
    int l,r;
    mutable int v;
    node(int l,int r,int v):l(l),r(r),v(v){}
    bool operator<(const node &x)const{return l<x.l;}
};set<node>tr;
it split(int pos){
    it x=tr.lower_bound(node(pos,0,0));
    if(x!=tr.end()&&x->l==pos)return x;
    --x;int l=x->l,r=x->r,v=x->v;
    tr.erase(x);
    tr.insert(node(l,pos-1,v));
    return tr.insert(node(pos,r,v)).first;
}void assign(int l,int r,int v){
    it end=split(r+1),begin=split(l);
    tr.erase(begin,end);
    tr.insert(node(l,r,v));
}void cure(int x,int y,int l,int r){
    it rt=split(y+1),lt=split(x);
    int sum=0;
    for(it t=lt;t!=rt;t++)if(t->v)sum+=(t->r-t->l+1);
    assign(x,y,0);
    rt=split(r+1),lt=split(l);
    for(it t=lt;t!=rt&&sum;t++){
        if(!t->v){
            int len=t->r-t->l+1;
            if(sum>len)t->v=1,sum-=len;
            else assign(t->l,t->l+sum-1,1),sum=0;
        }
    }
}int query(int l,int r){
    it rt=split(r+1),lt=split(l);
    int res=0,ans=0;
    for(it t=lt;t!=rt;t++){
        if(!(t->v))res+=(t->r-t->l+1);
        else ans=max(ans,res),res=0;
    }return max(ans,res);
}int n,q;
signed main(){
    scanf("%lld%lld",&n,&q);
    tr.insert(node(1,n,1));
    for(int i=1;i<=q;i++){
        int op,x,y,l,r;
        scanf("%lld%lld%lld",&op,&x,&y);
        if(op==0)assign(x,y,0);
        if(op==1)scanf("%lld%lld",&l,&r),cure(x,y,l,r);
        if(op==2)printf("%lld\n",query(x,y));
    }return 0;
}

可以顺利通过 Subtask 1 并拿到 0 分的好成绩,因为数据加强所以这道题目前只有线段树维护能过。

P2787 语文1(chin1)- 理理思维

维护一个字符串,有以下需求:

  • 获取第 x 到第 y 个字符中字母 k 出现次数。
  • 将第 x 到第 y 个字符赋值为字符 k
  • 将第 x 到第 y 个字符按字典序排序。

同样,我们对于加强前的数据可以用 ODT 实现,当中排序可以用桶排,复杂度在常数级别(因为总共也只有 26 个字母),同时由于此题大小写不敏感,所以我们需要先处理好字符串。

#include <bits/stdc++.h>
#define it set<node>::iterator
#define int long long 
using namespace std; 
struct node{
    mutable int l,r,v;
    node(int l,int r,int v):l(l),r(r),v(v){}
    bool operator<(const node &x)const{return l<x.l;}
};set<node>tr;
int change(char x){
    if('A'<=x&&x<='Z')return x-'A';
    return x-'a';
}it split(int pos){
    it x=tr.lower_bound(node(pos,0,0));
    if(x!=tr.end()&&x->l==pos)return x;
    --x;int l=x->l,r=x->r,v=x->v;
    tr.erase(x);
    tr.insert(node(l,pos-1,v));
    return tr.insert(node(pos,r,v)).first;
}int n,m,b[30];
char chr;
signed main(){
    scanf("%lld%lld",&n,&m);
    for(int i=1;i<=n;i++)
        cin>>chr,tr.insert(node(i,i,change(chr)));
    for(int i=1;i<=m;i++){
        int op,x,y,v;
        scanf("%lld%lld%lld",&op,&x,&y);
        it rt=split(y+1),lt=split(x);
        if(op==3){
            memset(b,0,sizeof b);
            while(lt!=rt){
                b[lt->v]+=lt->r-lt->l+1;
                it t=lt;lt++;tr.erase(t);
            }for(int i=0;i<26;i++)if(b[i]){
                tr.insert(node(x,x+b[i]-1,i));
                x+=b[i];
            }continue;
        }cin>>chr;v=change(chr);
        if(op==1){
            int res=0;
            for(;lt!=rt;lt++)if(lt->v==v)res+=lt->r-lt->l+1;
            printf("%lld\n",res);
        }if(op==2)tr.erase(lt,rt),tr.insert(node(x,y,v));
    }return 0;
}

上述代码可以通过除了最后一个强化测试点以外的数据。另外,此题的双倍经验 CF558E A Simple Task 要求实现降序和升序排序,而这依然是容易实现的,给出代码:

#include <bits/stdc++.h>
#define it set<node>::iterator
#define int long long 
using namespace std; 
struct node{
    mutable int l,r,v;
    node(int l,int r,int v):l(l),r(r),v(v){}
    bool operator<(const node &x)const{return l<x.l;}
};set<node>tr;
int change(char x){
    if('A'<=x&&x<='Z')return x-'A';
    return x-'a';
}it split(int pos){
    it x=tr.lower_bound(node(pos,0,0));
    if(x!=tr.end()&&x->l==pos)return x;
    --x;int l=x->l,r=x->r,v=x->v;
    tr.erase(x);
    tr.insert(node(l,pos-1,v));
    return tr.insert(node(pos,r,v)).first;
}int n,m,b[30];
char chr;
signed main(){
    scanf("%lld%lld",&n,&m);
    for(int i=1;i<=n;i++)
        cin>>chr,tr.insert(node(i,i,change(chr)));
    for(int i=1;i<=m;i++){
        int op,x,y;
        scanf("%lld%lld%lld",&x,&y,&op);
        it rt=split(y+1),lt=split(x);
        memset(b,0,sizeof b);
        while(lt!=rt){
            b[lt->v]+=lt->r-lt->l+1;
            it t=lt;lt++;tr.erase(t);
        }if(op){
            for(int i=0;i<=26;i++)if(b[i]){
                tr.insert(node(x,x+b[i]-1,i));
                x+=b[i];
            }
        }else{
            for(int i=26;i>=0;i--)if(b[i]){
                tr.insert(node(x,x+b[i]-1,i));
                x+=b[i];
            }
        }
    }for(it x=tr.begin();x!=tr.end();x++){
        for(int j=x->l;j<=x->r;j++){
            putchar(x->v+'a');
        }
    }puts("");
    return 0;
}

不才拙笔,感谢阅读。