可持久化 Trie 学习笔记

· · 算法·理论

前置知识:Trie,可持久化线段树。

本文默认读者已经掌握以上两项内容。

详解

P4735 最大异或和

由于你谷没有正式意义上的模板题,所以我们以本题为例。

:::info[题意]{open} 给你一个初始长度为 n 的序列。有 m 次操作。

首先我们考虑简化版:假设现在有一个序列 a,每次给你一个数 x,让你求最大的 x\oplus a_i

这是一个经典的 0/1 Trie 问题,我们建立起序列 a 的 0/1 Trie 树,将 x 放在树上检索。每在树上往下一层代表进入了下一个二进制位。设 x 在这个位上的值为 bit,那么我们就优先考虑往 !bit 走,这样可以使得该位上异或和为 1 从而最优。如果 !bit 子树上没有点那么就只能往 bit 走。最后走到的点即为最优答案。

那么考虑这个题。首先区间异或 \bigoplus_{i=p}^n 很好解决,转化为前缀异或 s_n\oplus s_{p-1},那么问题转化为:给定区间 [l,r] 和数 x,需要找到一个位置 p\in[l-1,r-1],使得 (x\oplus s_n)\oplus s_p 最大并输出这个最大值。

这个问题相较于简化版的变化是:简化版是查询整个序列的“最大异或和”,而这个问题是查询序列上某个区间的“最大异或和“。

我们显然不能在每次询问时对于询问区间暴力建立 Trie 树再去做。我们考虑再进行前缀和优化:对于 s 序列的 1\sim i 号元素,建立版本 T_i,那么我们需要的 Trie 树即为版本 T_{r-1}-T_{l-2}

最后是建立新版本不用暴力枚举 [1,i] 建立因为这样复杂度会爆。可以从上一个版本 T_{i-1} 复制过来并插入新元素 s_i 即可。这个理解是和主席树相似的,这里就不予叙述了。

:::success[代码(有注释)]

#include<bits/stdc++.h>
using namespace std;
const int N=6e5+5;
int n,m,cnt,a[N],siz[N*62],ch[N*62][2],rt[N];
void ins(int &rt,int x,int d){//新建版本,rt是当前走到节点,x是新插入的数,d是当前深度
    ch[++cnt][0]=ch[rt][0],ch[cnt][1]=ch[rt][1];//其实就是复制上一个版本的节点
    siz[cnt]=siz[rt]+1,rt=cnt;//siz+1是更新新的节点的答案
    if(~d)ins(ch[rt][(x>>d)&1],x,d-1);//x在d位的值是哪个就往哪边走(去更新节点答案),没走的那边相当于是和上一个版本一样的数据,所以直接继承节点即可
}
int qy(int rt1,int rt2,int x,int d){//查询版本T[rt1]-T[rt2]的答案
    if(d==-1)return 0;//走到头了
    int pos=(x>>d)&1;pos^=1;//优先找!bit那一边
    if(siz[ch[rt1][pos]]-siz[ch[rt2][pos]])return qy(ch[rt1][pos],ch[rt2][pos],x,d-1)|(1<<d);//siz!=0表示有数,可以往下走,此时这一位的异或值为1,所以加上答案2^d
    else return qy(ch[rt1][pos^1],ch[rt2][pos^1],x,d-1);//否则只能往bit的那一边走
}
int main(){
    cin>>n>>m;
    ins(rt[0],0,25);
    for(int i=1;i<=n;i++)cin>>a[i],a[i]^=a[i-1],ins(rt[i]=rt[i-1],a[i],25);//rt[i]=rt[i-1]:继承
    while(m--){
        char op;int x,y,z,t;
        cin>>op>>x;
        if(op=='A')a[++n]=x,a[n]^=a[n-1],ins(rt[n]=rt[n-1],a[n],25);
        else{
            cin>>y>>z,t=a[n]^z,x--,y--;//x--,y--是因为用了前缀和,所以需要找的是前缀和序列[l-1,r-1]的答案
            if(x)cout<<qy(rt[y],rt[x-1],t,25)<<"\n";
            else cout<<qy(rt[y],0,t,25)<<"\n";
        }
    }
    return 0;
}

:::

总结

可持久化 Trie 的一个常见用途是:在序列上某一段区间 a_{[l,r]} 中找到与 x 的最大异或值。这个时候的 Trie 其实是 0/1 Trie。

扩展性地,处理的是在区间 a_{[l,r]} 中和 x 相关的一些特定信息。

如果有别的更高妙的引申或应用可以与我交流。

练习

题目不是很多。

P4592 [TJOI2018] 异或

:::info[题意]{open}

有一棵以 1 为根节点的由 n 个节点组成的树,节点从 1n 编号。树上每个节点上都有一个权值 v_i。现在有 q 次操作。

看到异或最大值很自然的想到可持久化 Trie。

对于第一个询问,使用 dfn 序将子树转化为一段连续的区间,然后用可持久化 Trie 维护这段区间的答案即可。

对于第二个询问,求出 x,y 的 LCA t。对于每一个节点建立其版本 T_x 表示从根节点到 x 的路径上所有点的 Trie 树集合。那么答案可以拆分成 t\to x 的答案和 t\to y 的答案的最大值。而 t\to x 的答案就是 T_{x}-T_{fa_t}

也就是说我们要维护两个可持久化 Trie,一个维护 dfn 区间处理第一个询问,一个维护树上的链处理第二个询问。

似乎第二个询问用树剖和 dfn 序也是可以的(?)但是这显然不如如上方法来的简单。

:::success[代码]

#include<bits/stdc++.h>
using namespace std;
const int N=1e5+5;
int n,q,tot,a[N],dfn[N],dep[N],sz[N],dp[N][25];
vector<int>v[N];
int cnt,siz[N*70],ch[N*70][2],rt1[N],rt2[N];
int LCA(int x,int y){
    if(dep[x]>dep[y])swap(x,y);
    for(int i=20;i>=0;i--)if(dep[x]<=dep[dp[y][i]])y=dp[y][i];
    if(x==y)return x;
    for(int i=20;i>=0;i--)if(dp[x][i]!=dp[y][i])x=dp[x][i],y=dp[y][i];
    return dp[x][0];
}
void ins(int &rt,int x,int d){
    ch[++cnt][0]=ch[rt][0],ch[cnt][1]=ch[rt][1];
    siz[cnt]=siz[rt]+1,rt=cnt;
    if(~d)ins(ch[rt][(x>>d)&1],x,d-1);
}
void dfs(int x,int fa){
    sz[x]=1,dp[x][0]=fa,dep[x]=dep[fa]+1,dfn[x]=++tot;
    for(int i=1;(1<<i)<=dep[x];i++)dp[x][i]=dp[dp[x][i-1]][i-1];
    ins(rt1[tot]=rt1[tot-1],a[x],30);
    ins(rt2[x]=rt2[fa],a[x],30);
    for(int i=0;i<v[x].size();i++){
        int to=v[x][i];
        if(to==fa)continue;
        dfs(to,x),sz[x]+=sz[to];
    }
    return;
}
int qy(int rt1,int rt2,int x,int d){
    if(d==-1)return 0;
    int pos=(x>>d)&1;pos^=1;
    if(siz[ch[rt1][pos]]-siz[ch[rt2][pos]])return qy(ch[rt1][pos],ch[rt2][pos],x,d-1)|(1<<d);
    else return qy(ch[rt1][pos^1],ch[rt2][pos^1],x,d-1);
}
int main(){
    cin>>n>>q;
    for(int i=1;i<=n;i++)cin>>a[i];
    for(int i=1,x,y;i<n;i++){
        cin>>x>>y;
        v[x].push_back(y);
        v[y].push_back(x);
    }
    dfs(1,0);
    for(int i=1,op,x,y,z,t;i<=q;i++){
        cin>>op>>x>>y;
        if(op==1){
            cout<<qy(rt1[dfn[x]+sz[x]-1],rt1[dfn[x]-1],y,30)<<"\n";
        }
        else cin>>z,t=dp[LCA(x,y)][0],cout<<max(qy(rt2[x],rt2[t],z,30),qy(rt2[y],rt2[t],z,30))<<"\n";
    }
    return 0;
}

:::

P6088 [JSOI2015] 字符串树

:::info[题意]{open}

给定一棵 n 个点的树,树上每条边上都有一个字符串。m 次询问,每次询问给定点 u,v 和字符串 s。求树上 u\to v 的路径上有多少个字符串以 s 为前缀。

保证所有字符串的长度都不超过 10。 :::

这个题其实和上面那个题差不多,都是处理树上路径的一些问题。只是这次我们要检索的是字符串,所以 0/1 Trie 变成了 Trie。

我们还是建立版本 T_x 表示树上点 x 到根路径上所有字符串建立的 Trie。T_x 可以从 T_{fa} 继承。假设 f_{u,s} 表示在版本 T_u 上检索有多少字符串是以 s 为前缀。那么直接在 T_u 上找到 s 对应的节点然后计算这个节点的 siz 即可。对于询问 u,v,s,答案就是 f_{u,s}+f_{v,s}-2\times f_{\text{lca}(u,v),s}

可能比上面那个题还容易一些?

:::success[代码]

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e5+5;
int n,m,cnt,dep[N],dp[N][25],T[N],ch[N*10][26],siz[N*10];
struct edge{int x;string s;};
vector<edge>v[N];
void ins(int &rt,string x,int d){
    cnt++;for(int i=0;i<26;i++)ch[cnt][i]=ch[rt][i];
    siz[cnt]=siz[rt]+1,rt=cnt;
    if(d<x.size())ins(ch[rt][x[d]-'a'],x,d+1);
    return;
}
void dfs(int x,int fa){
    dep[x]=dep[fa]+1,dp[x][0]=fa;
    for(int i=1;(1<<i)<=dep[x];i++)dp[x][i]=dp[dp[x][i-1]][i-1];
    for(auto it:v[x]){
        if(it.x==fa)continue;
        ins(T[it.x]=T[x],it.s,0),dfs(it.x,x);
    }
    return;
}
int lca(int x,int y){
    if(dep[x]>dep[y])swap(x,y);
    for(int i=20;i>=0;i--)if(dep[x]<=dep[dp[y][i]])y=dp[y][i];
    if(x==y)return x;
    for(int i=20;i>=0;i--)if(dp[x][i]!=dp[y][i])x=dp[x][i],y=dp[y][i];
    return dp[x][0];
}
int qy(int rt,string s){
    for(int i=0;i<s.size();i++){
        if(!ch[rt][s[i]-'a'])return 0;
        rt=ch[rt][s[i]-'a'];
    }
    return siz[rt];
}
signed main(){
    cin>>n;
    for(int i=1,x,y;i<n;i++){
        string s;cin>>x>>y>>s;
        v[x].push_back({y,s});
        v[y].push_back({x,s});
    }
    dfs(1,0);
    cin>>m;
    while(m--){
        int x,y;string s;
        cin>>x>>y>>s;
        cout<<qy(T[x],s)+qy(T[y],s)-2*qy(T[lca(x,y)],s)<<"\n";
    }
    return 0;
}

:::

P4098 [HEOI2013] ALO

:::info[题意]{open} 给定一个 n 个数的序列 a。我们定义一个区间(长度不小于 2)的权值为这个区间的次大值与这个区间内其他任意数的异或的值的最大值。

请你在 a 中选取一段区间使得这段区间的权值最大。输出这个最大值。 :::

感觉像 trick 拼好题。

有一个套路是固定一个次大值 a_i,然后找 a_i 可以在哪些区间成为次大值。

我们找到 a_i 左边和右边第一个和第二个比它大的位置 l_1,l_2,r_1,r_2,那么有两个区间可以包含所有情况,即 [l_1+1,r_2-1],[l_2+1,r_1-1]。我们分别找到这两个区间内部与 a_i 异或出来最大的值并取最值即可。这个是可持久化 Trie 的经典形式。

最后的问题是怎么找到 l_1,l_2,r_1,r_2,这个也很套路。我们用链表维护所有点,从小到大删所有点,删到一个点时,其左右两边的两个点就是答案。

需要特判一下如果左右两边不足两个点的情况。

:::success[代码]

#include<bits/stdc++.h>
#define int long long
#define pii pair<int,int>
#define fi first
#define se second
using namespace std;
const int N=5e4+5,Hd=0,Tl=5e4+1;
int n,ans,cnt,a[N],b[N],L[N],R[N],siz[N*70],ch[N*70][2],rt[N];
pii e[N],t1[N],t2[N];
void lk(int x,int y){return L[y]=x,R[x]=y,void();}
void del(int x){return lk(L[x],R[x]),void();}
void ins(int &rt,int x,int d){
    ch[++cnt][0]=ch[rt][0],ch[cnt][1]=ch[rt][1];
    siz[cnt]=siz[rt]+1,rt=cnt;
    if(~d)ins(ch[rt][(x>>d)&1],x,d-1);
}
int qy(int rt1,int rt2,int x,int d){
    if(d==-1)return 0;
    int pos=(x>>d)&1;pos^=1;
    if(siz[ch[rt1][pos]]-siz[ch[rt2][pos]])return qy(ch[rt1][pos],ch[rt2][pos],x,d-1)|(1<<d);
    else return qy(ch[rt1][pos^1],ch[rt2][pos^1],x,d-1);
}
signed main(){
    cin>>n,lk(n,Tl);
    for(int i=1;i<=n;i++)cin>>a[i],lk(i-1,i),e[i]={a[i],i};
    sort(e+1,e+n+1);
    for(int i=1;i<=n;i++){
        int x=e[i].se;
        if(L[x]==Hd){
            if(R[x]==Tl)t1[x]=t2[x]={-1,-1};
            else if(R[R[x]]==Tl)t1[x]={1,n},t2[x]={-1,-1};
            else t1[x]={1,R[R[x]]-1},t2[x]={-1,-1};
        }
        else if(L[L[x]]==Hd){
            if(R[x]==Tl)t1[x]={1,n},t2[x]={-1,-1};
            else{
                t1[x]={1,R[x]-1};
                if(R[R[x]]==Tl)t2[x]={L[x]+1,n};
                else t2[x]={L[x]+1,R[R[x]]-1};
            }
        }
        else{
            if(R[x]==Tl)t1[x]={L[L[x]]+1,n},t2[x]={-1,-1};
            else if(R[R[x]]==Tl)t1[x]={L[L[x]]+1,R[x]-1},t2[x]={L[x]+1,n};
            else t1[x]={L[L[x]]+1,R[x]-1},t2[x]={L[x]+1,R[R[x]]-1};
        }
        del(x);
    }
    ins(rt[0],0,31);
    for(int i=1;i<=n;i++)ins(rt[i]=rt[i-1],a[i],31);
    for(int i=1;i<=n;i++){
        int x=t1[i].fi,y=t1[i].se;
        if(x==-1)continue;
        ans=max(ans,qy(rt[y],rt[x-1],a[i],31));
        x=t2[i].fi,y=t2[i].se;
        if(x==-1)continue;
        ans=max(ans,qy(rt[y],rt[x-1],a[i],31));
    }
    cout<<ans;
    return 0;
}

:::

P5283 [十二省联考 2019] 异或粽子

:::info[题意]{open} 给定一个长度为 n 的序列 a 和一个整数 m。设一个区间 [l,r] 的权值为 \bigoplus_{i=l}^r a_i。试求序列内权值前 m 大的区间的权值之和。 :::

先做一个前缀异或,那么区间的权值是 s_r\oplus s_{l-1}

固定右端点 r,用可持久化 Trie 在 s_{[0,r-1]} 中找到与 s_r 异或值最大的那一个。初始时把每一个 r 的这个值都放入大根堆中。

每次在堆中取出值最大的那一个状态。设这个状态右端点为 r,对应的最优的左端点为 p,累加答案后把左端点区间分割成 s_{[0,p-1]}s_{[p+1,r-1]} 重新插入即可。

另外,我们需要记录 p 的位置而不是直接记录最大的 val。所以需要在模板上改动一点地方,读者可尝试自行实现。

:::success[代码]

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=5e5+5;
int n,m,ans,cnt,a[N],siz[N*35],ch[N*35][2],T[N],id[N*35];
void ins(int &rt,int x,int d,int idd){
    ch[++cnt][0]=ch[rt][0],ch[cnt][1]=ch[rt][1];
    siz[cnt]=siz[rt]+1,rt=cnt;
    if(~d)ins(ch[rt][(x>>d)&1],x,d-1,idd);
    else id[rt]=idd;
}
struct node{
    int x,p,l,r,tt;
    bool operator<(const node&X)const{return x<X.x;}
};
priority_queue<node>q;
int qy(int rt1,int rt2,int x,int d){
    if(d==-1)return id[rt1];
    int pos=(x>>d)&1;pos^=1;
    if(siz[ch[rt1][pos]]-siz[ch[rt2][pos]])return qy(ch[rt1][pos],ch[rt2][pos],x,d-1);
    else return qy(ch[rt1][pos^1],ch[rt2][pos^1],x,d-1);
}
signed main(){
    cin>>n>>m,ins(T[0],0,33,0);
    for(int i=1;i<=n;i++)cin>>a[i],a[i]^=a[i-1],ins(T[i]=T[i-1],a[i],33,i);
    for(int i=1;i<=n;i++){
        int p=qy(T[i-1],0,a[i],33);
        q.push({(a[i]^a[p]),p,0,i-1,i});
    }
    while(m--){
        node x=q.top();q.pop();
        ans+=x.x;
        if(x.l<x.p){
            int p=qy(T[x.p-1],(!x.l?0:T[x.l-1]),a[x.tt],33);
            q.push({(a[x.tt]^a[p]),p,x.l,x.p-1,x.tt});
        }
        if(x.p<x.r){
            int p=qy(T[x.r],T[x.p],a[x.tt],33);
            q.push({(a[x.tt]^a[p]),p,x.p+1,x.r,x.tt});
        }
    }
    cout<<ans;
    return 0;
}

:::

CF1777F Comfortably Numb

:::info[题意]{open} 给定一个 n 个数的序列 a。我们定义一个区间 [l,r] 的权值为 \max(a_l,\dots,a_r)\oplus (a_l\oplus\dots\oplus a_r)

请你在 a 中选取一段区间使得这段区间的权值最大。输出这个最大值。 :::

套路和上面那道 P4098 是类似的。框出每个值 i 作为最大值的区间 [L,R]。在 [L,i][i,R] 内选取 l,r,那么权值就是 a_i\oplus s_r \oplus s_{l-1}

这里有一个暴力的想法是枚举 [L,i],[i,R] 中长度较小的一边作为左右端点,另一边用可持久化 Trie 处理答案。比如说假设 [L,i] 长度较小,我们枚举 l\in [L,i],此时 a_i\oplus s_{l-1} 已经是定值。然后再在 s_{[i,R]} 中找最大的异或值。

这样的复杂度其实是对的。感性理解就是用启发式合并的思想,但是这里给出证明。

:::info[时间复杂度证明]{open} 对于原序列造出笛卡尔树,那么选取 [L,i],[i,R] 中长度较小的遍历本质上是选择树上节点 i 的左右儿子中 siz 较小的一个遍历。

设对于节点 i,其左右儿子分别为 ls,rs,不妨设 siz_{ls}<siz_{rs},则此时 ls 内的所有节点被遍历(计数)一次。

对于一个节点 x,它被计数当且仅当它所在子树作为较小子树被合并到父节点时。每次合并后,它新的所在子树大小至少翻倍(因为它所在的是较小子树),那么每个 x 都至多被计数 \log n 次,总的复杂度就是 \mathcal{O(n\log n)}。 :::

加上可持久化 Trie 的一只 \log,总复杂度是 \mathcal{O(n\log^2 n)} 的。

:::success[代码]

#include<bits/stdc++.h>
using namespace std;
const int N=2e5+5;
int T,n,cnt,ans,a[N],L[N],R[N],siz[N*35],ch[N*35][2],rt[N],s[N];
stack<int>stk;
void ins(int &rt,int x,int d){
    ch[++cnt][0]=ch[rt][0],ch[cnt][1]=ch[rt][1];
    siz[cnt]=siz[rt]+1,rt=cnt;
    if(~d)ins(ch[rt][(x>>d)&1],x,d-1);
}
int qy(int rt1,int rt2,int x,int d){
    if(d==-1)return 0;
    int pos=(x>>d)&1;pos^=1;
    if(siz[ch[rt1][pos]]-siz[ch[rt2][pos]])return qy(ch[rt1][pos],ch[rt2][pos],x,d-1)|(1<<d);
    else return qy(ch[rt1][pos^1],ch[rt2][pos^1],x,d-1);
}
void work(){
    cin>>n;
    for(int i=1;i<=n;i++)cin>>a[i],rt[i]=0;
    for(int i=1;i<=n;i++){
        while(!stk.empty()){
            if(a[stk.top()]<a[i])R[stk.top()]=i,stk.pop();
            else break;
        }
        if(!stk.empty())L[i]=stk.top();
        else L[i]=0;
        stk.push(i);
    }
    while(!stk.empty())R[stk.top()]=n+1,stk.pop();
    for(int i=1;i<=cnt;i++)ch[i][0]=ch[i][1]=siz[i]=0;
    cnt=ans=0,ins(rt[0],0,31);
    for(int i=1;i<=n;i++)L[i]++,R[i]--,s[i]=s[i-1]^a[i],ins(rt[i]=rt[i-1],s[i],31);
    for(int i=1;i<=n;i++){
        int siz1=i-L[i]+1,siz2=R[i]-i+1;
        if(siz1<siz2){
            for(int j=L[i];j<=i;j++){
                int val=a[i]^s[j-1];
                ans=max(ans,qy(rt[R[i]],rt[i-1],val,31));
            }
        }
        else{
            for(int j=i;j<=R[i];j++){
                int val=a[i]^s[j];
                ans=max(ans,qy(rt[i-1],(L[i]==1?0:rt[L[i]-2]),val,31));
            }
        }
    }
    cout<<ans<<"\n";
    return;
}
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0);
    cin>>T;
    while(T--)work();
    return 0;
}

:::

后记

本人才疏学浅,如果有内容上的错漏可以私信与我交流。