浅谈 Trie 及其应用

· · 个人记录

浅谈 Trie 及其应用

什么是 Trie 树

想象一下翻英文字典的过程,是不是从首字母开始,一位一位的找找下去,最后找到你想要的词?Trie 树完成的就是这个工作。

比如我要插入一个 boy,先定位第一个字母,发现是 b,于是来到 0 的儿子 b,之后定位第二个字母 o,来到b 的儿子 o,最后定位第三个字母 y,发现 o 还没有y 这个儿子,就新建一个 y。Trie 树的每一个节点理论上都有 |字符集大小| 个儿子,这里就是 26 个,如果要查询一个串是否存在,顺着儿子往下走即可。

模板题

P2580 于是他错误的点名开始了

给你一些串组成的集合,每次询问一个串是否出现在这个集合里和之前是否询问过这个串。

先把集合插入 Trie 中,询问时顺着 Trie 往下走,直到走不了(没有儿子)或走完为止。

如何判断询问串究竟是串本身还是串的前缀呢?我们可以对插入的每个串的结尾做一个标记,则在标记处停止的询问串出现在了集合中,是否重复出现在标记处记录是否访问过即可。

在实现的过程中,需要对 Trie 树的空间格外注意,一般开 n\cdot|s| 个,其中 n 是字符串个数, |s| 是字符串的最大长度,这是因为每次插入一个串,最多可能多开 |s| 个节点(和之前的任何一个串都没有共同前缀)。这也就意味着一棵 Trie 树耗费的空间可能多达 n\cdot |s|\cdot 26 个 int,MLE 的风险很高,要尽量节省一些空间。

Code:

#include<bits/stdc++.h>
using namespace std;
const int N = 5e5+10;
int n,m,tr[N][26],tot;
bool exi[N],vis[N];
char s[N];
void insert(char *s){
    int len = strlen(s+1),now = 0;
    for (int i=1;i<=len;i++){
        int c = s[i]-'a';
        if (!tr[now][c]) tr[now][c] = ++tot;
        now = tr[now][c];
    }
    exi[now] = 1;
}
void query(char *s){
    int len = strlen(s+1),now = 0;
    for (int i=1;i<=len;i++){
        int c = s[i]-'a';
        if (!tr[now][c]){
            cout<<"WRONG"<<"\n";
            return;
        }
        now = tr[now][c];
    }
    if (!exi[now]) cout<<"WRONG"<<"\n";
    else if (vis[now]) cout<<"REPEAT"<<"\n";
    else{
        vis[now] = 1;
        cout<<"OK"<<"\n";
    }
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin>>n;
    for (int i=1;i<=n;i++){
        cin>>(s+1);
        insert(s);
    }
    cin>>m;
    for (int i=1;i<=m;i++){
        cin>>(s+1);
        query(s);
    }
    return 0;
}

在 Trie 上每次插入或查询一个字符串的复杂度都为 O(|s|)

Trie 树有很多有用的性质,比如一个串的前缀表示,一些串的最长公共前缀等,这使它能处理一些特殊问题。

例题1

UVA1401 Remember the Word

给你一些小字符串和一个大串,求用这些小串拼成大串的方案数。

一般求方案数的就是 DP,可以设计这样的 DP 状态:

f[i] 表示拼成前缀 i 的方案数,每次枚举 j,看 s[j~i] 这个串是否是一个小串,如果是,就加上 f[j-1]

但是,这样枚举 j 的复杂度很高(Hash 另说),我们不妨换一种枚举顺序:固定一个 i,枚举结尾 j,每次判断 s[i+1~j] 是否为一个小串,这样巧妙地利用了 Trie 存前缀的性质,每次扩展一个字符的时间复杂度为 O(1),由于小串的长度最大为 50,所以时间复杂度为 O(50n)

每次向右扩展一个字符可以往 Trie 上面想。

Code:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 4e5+10,mod = 20071027;
int n,m,tr[N][27],tot,vis[N][27],tim;
ll f[N];
bool exi[N];
char s[N],t[N];
void insert(char *s){
    int len = strlen(s+1),now = 0;
    for (int i=1;i<=len;i++){
        int c = s[i]-'a';
        if (vis[now][c]!=tim){
            vis[now][c] = tim;
            tr[now][c] = ++tot;
        }
        now = tr[now][c];
    }
    exi[now] = 1;
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    while (cin>>(s+1)){
        tim++;
        memset(exi,0,sizeof(bool)*(tot+10));
        tot = 0;
        n = strlen(s+1);
        memset(f,0,sizeof(ll)*(n+10));
        cin>>m;
        for (int i=1;i<=m;i++){
            cin>>(t+1);
            insert(t);
        }
        f[n+1] = 1;
        for (int i=n;i>=1;i--){
            int now = 0;
            for (int j=i;j<=n;j++){
                int c = s[j]-'a';
                if (vis[now][c]!=tim) break;
                now = tr[now][c];
                if (exi[now]) f[i] = (f[i]+f[j+1])%mod;
            }
        }
        cout<<"Case "<<tim<<": "<<f[1]<<"\n";
    }
    return 0;
}

例题2

P4551 最长异或路径

先转化一下题意:

算出每个点到根的异或值,则要求的最大异或路径等于两个点到根的异或的最大值,因为根到两点 LCA 的路径异或值抵消了。

转化为求 n 个数里面两个数异或的最大值。

把异或值拆成二进制,按位从高到低贪心,有和这一位的值相反的就取(因为后面的所有低位即使都是 1,也比这一位小)。

如何高效的完成这个算法呢?发现所有数的高位组成的数就是它的前缀,而我们就是在前缀上比较,这类问题可以放到 Trie 上做,达到 O(32n) 的复杂度。

Code:

#include<bits/stdc++.h>
using namespace std;
const int N = 1e5+10;
int n,head[N],tot,ans;
struct e{
    int next,to,w;
} edge[N*2];
struct Trie{
    int tr[N*26][3],tot;
    void update(int x){
        int now = 0;
        for (int i=31;i>=0;i--){
            int c = (x>>i)&1;
            if (!tr[now][c]) tr[now][c] = ++tot;
            now = tr[now][c];
        }
    }
    int query(int x){
        int now = 0,ans = 0;
        for (int i=31;i>=0;i--){
            int c = (x>>i)&1;
            if (tr[now][c^1]){
                now = tr[now][c^1];
                ans+=(1<<i);
            }
            else now = tr[now][c];
        }
        return ans;
    }
} t;
void add(int x,int y,int w){
    edge[++tot].to = y;
    edge[tot].next = head[x];
    edge[tot].w = w;
    head[x] = tot;
}
void dfs(int u,int fa,int sum){
    int v;
    t.update(sum);
    ans = max(ans,t.query(sum));
    for (int i=head[u];i;i = edge[i].next){
        v = edge[i].to;
        if (v == fa) continue;
        dfs(v,u,sum^edge[i].w);
    }
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin>>n;
    int x,y,z;
    for (int i=1;i<n;i++){
        cin>>x>>y>>z;
        add(x,y,z);
        add(y,x,z);
    }
    dfs(1,0,0);
    cout<<ans<<endl;
    return 0;
}

这种把数拆成二进制之后放到 Trie 上的叫做 01Trie,这道题只是理清了 01Trie 的基本架构,它还有一些非常好的性质和衍生的操作。

01Trie

支持的操作:

我们一步步实现。

插入和删除都是从低位到高位(其实算习惯吧,为了之后写代码方便)。

在每一个 Trie 的节点 now 上记录以下几个值:

ch[now][0/1] 指向它的儿子
w[now] 表示有多少个数在插入时经过了它
sum[now] 表示 now 这棵子树的异或和(但是是把 now 所在的层当作最低位,并不是真正的子树异或和)

怎么维护呢?先看代码吧(感觉代码可能比文字更好懂,反正文字也在代码里面)。

void insert(int &now,int x,int dep){ //引用方便修改ch[now][0/1]
    if (!now) now = push(); //新建节点
    if (dep>H){ //H为最大深度,一般要保证答案在2^H以内
        w[now]++;
        return;
    }
    insert(ch[now][x&1],x>>1,dep+1); //从低位到高位插入
    pushup(now); //维护 sum 和 w
}

再看看 pushup 操作。

void pushup(int now){
    w[now] = sum[now] = 0; //初始化为0
    if (ch[now][0]){ //如果把now->ch[now][0]看作一条边,那么0就相当于边权,代表第dep+1位为0
        w[now]+=w[ch[now][0]]; //经过儿子的数插入时都会经过父亲
        sum[now]^=(sum[ch[now][0]]<<1); //第dep+1位为0,对这一位的异或和没有影响
    }
    if (ch[now][1]){ //同上
        w[now]+=w[ch[now][1]]; //经过儿子的数插入时都会经过父亲
        sum[now]^=((sum[ch[now][1]]<<1)|(w[ch[now][1]]&1)); //第dep+1位为1,异或和只和1的奇偶性相关
    }
    w[now]&=1; //因为只和奇偶性相关,所以不用存具体值(要存也没关系,维护的信息会更具体,取决于题目)
}

erase 操作如法炮制。

void erase(int now,int x,int dep){
    if (dep>H){
        w[now]--; //如果不要求求具体数量的话,这里改成+1也行,或者直接不要erase,再insert一遍(异或两次等于没异或)
        return;
    }
    erase(ch[now][x&1],x>>1,dep+1);
    pushup(now);
}

更短的 addall(全局 +1)操作。

void addall(int now){
    swap(ch[now][0],ch[now][1]);
    if (ch[now][0]) addall(ch[now][0]);
    pushup(now);
}

这个操作具体解释一下:

每个数 +1,在二进制下表现为从最低位开始的为 1 的位都变成 0,第一个为 0 的位变成 1

再次思考父亲和儿子之间唯一的联系:边权。这一位是啥,边权就是啥,我们交换左右儿子就相当于交换边权。

最后,因为要修改的是为 1 的位,交换后变成了为 0 的位,只要还有,就继续往下修改。

查询所有数的异或和更简单了,直接输出 Trie 树的根的异或和即可。

例题3

P6018 Ynoi2010 Fusion tree

这种树上的维护,一般分成维护一个点的所有儿子,父亲单独维护。

操作 1 相当于所有儿子 +1,父亲 +1。儿子的 +1 直接 addall,父亲的 +1 单独维护,顺便把父亲的父亲的树一起维护掉,保证正确性。

操作 2 相当于自己 -v,先在父亲的 Trie 中把自己删除,再插入更新后的自己即可。

操作 3 直接输出对应 Trie 树根节点的异或和异或上父亲的值即可。

Code:

#include<bits/stdc++.h>
using namespace std;
const int N = 5e5+10,H = 21;
int n,m,head[N],tot,val[N],rt,fa[N],lazy[N];
struct e{
    int next,to;
} edge[N*2];
struct Trie{
    int rt[N],w[N*H],ch[N*H][2],sum[N*H],tot;
    int push(){
        tot++;
        ch[tot][0] = ch[tot][1] = 0;
        w[tot] = sum[tot] = 0;
        return tot;
    }
    void pushup(int node){
        w[node] = sum[node] = 0;
        if (ch[node][0]){
            w[node]+=w[ch[node][0]];
            sum[node]^=(sum[ch[node][0]]<<1);
        }
        if (ch[node][1]){
            w[node]+=w[ch[node][1]];
            sum[node]^=((sum[ch[node][1]]<<1)|(w[ch[node][1]]&1));
        }
        w[node]&=1;
    }
    void insert(int &now,int x,int dep){
        if (!now) now = push();
        if (dep>H){
            w[now]++;
            return;
        }
        insert(ch[now][x&1],x>>1,dep+1);
        pushup(now);
    }
    void erase(int now,int x,int dep){
        if (dep>H){
            w[now]--;
            return;
        }
        erase(ch[now][x&1],x>>1,dep+1);
        pushup(now);
    }
    void addall(int now){
        swap(ch[now][0],ch[now][1]);
        if (ch[now][0]) addall(ch[now][0]);
        pushup(now);
    }
} t;
void add(int x,int y){
    edge[++tot].to = y;
    edge[tot].next = head[x];
    head[x] = tot;
}
void dfs(int u,int f){
    fa[u] = f;
    int v;
    for (int i=head[u];i;i = edge[i].next){
        v = edge[i].to;
        if (v == f) continue;
        dfs(v,u);
    }
}
int get(int x){
    return val[x]+(fa[x]?lazy[fa[x]]:0);
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin>>n>>m;
    int x,y;
    for (int i=1;i<n;i++){
        cin>>x>>y;
        add(x,y);
        add(y,x);
    }
    dfs(rt = y,0);
    for (int i=1;i<=n;i++){
        cin>>val[i];
        if (fa[i]) t.insert(t.rt[fa[i]],val[i],0); //这里不用担心rt初始为0的问题,因为是引用,如果没有rt会自动创建一个 
    }
    int op;
    for (int i=1;i<=m;i++){
        cin>>op>>x;
        if (op == 1){
            lazy[x]++;
            if (fa[x]){
                if (fa[fa[x]]) t.erase(t.rt[fa[fa[x]]],get(fa[x]),0);
                val[fa[x]]++;
                if (fa[fa[x]]) t.insert(t.rt[fa[fa[x]]],get(fa[x]),0);
            }
            t.addall(t.rt[x]);
        }
        else if (op == 2){
            cin>>y;
            if (fa[x]) t.erase(t.rt[fa[x]],get(x),0);
            val[x]-=y;
            if (fa[x]) t.insert(t.rt[fa[x]],get(x),0);
        }
        else cout<<(t.sum[t.rt[x]]^(fa[x]?get(fa[x]):0))<<"\n";
    }
    return 0;
}

01Trie 合并

先看代码。

int merge(int a,int b){ //a和b为两个树上位置相同的节点
    if (!a) return b;
    if (!b) return a; //如果其中一个不存在,那它的子树也不存在,所以不用修改子树
    w[a]+=w[b];
    sum[a]^=sum[b]; //直接修改
    ch[a][0] = merge(ch[a][0],ch[b][0]);
    ch[a][1] = merge(ch[a][1],ch[b][1]); //更新两个儿子
    return a;
}

每次合并两个 01Trie 时,看成小的 Trie 往大的 Trie 里面合并,那么小的 Trie 中的元素所在集合大小就会翻倍,每次合并至少让一个 Trie 中的元素大小翻倍,最后所有元素的大小都到了 n(属于同一个大集合),所以一个元素被合并的次数为 \log n 次,时间复杂度为 O(n\log n)

例题4

P6623 省选联考 2020 A 卷 树

容易想到合并子树的 Trie,之后全局 +1,再加上自己,dfs 一遍,加上每个点的异或和就是答案。

Code:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 6e5+10,H = 23;
int n,val[N],head[N],tot;
ll ans;
struct e{
    int next,to;
} edge[N*2];
struct Trie{
    int ch[N*H][2],w[N*H],rt[N],tot;
    ll sum[N*H];
    int push(){
        tot++;
        return tot;
    }
    void pushup(int now){
        w[now] = sum[now] = 0;
        if (ch[now][0]){
            w[now]+=w[ch[now][0]];
            sum[now]^=(sum[ch[now][0]]<<1);
        }
        if (ch[now][1]){
            w[now]+=w[ch[now][1]];
            sum[now]^=((sum[ch[now][1]]<<1)|(w[ch[now][1]]&1));
        }
        w[now]&=1;
    }
    void insert(int &now,int x,int dep){
        if (!now) now = push();
        if (dep>H){
            w[now]++;
            return;
        }
        insert(ch[now][x&1],x>>1,dep+1);
        pushup(now);
    }
    void addall(int now){
        swap(ch[now][0],ch[now][1]);
        if (ch[now][0]) addall(ch[now][0]);
        pushup(now);
    }
    int merge(int a,int b){
        if (!a) return b;
        if (!b) return a;
        w[a]+=w[b];
        sum[a]^=sum[b];
        ch[a][0] = merge(ch[a][0],ch[b][0]);
        ch[a][1] = merge(ch[a][1],ch[b][1]);
        return a;
    }
} t;
void add(int x,int y){
    edge[++tot].to = y;
    edge[tot].next = head[x];
    head[x] = tot;
}
void dfs(int u,int fa){
    int v;
    for (int i=head[u];i;i = edge[i].next){
        v = edge[i].to;
        if (v == fa) continue;
        dfs(v,u);
        t.rt[u] = t.merge(t.rt[u],t.rt[v]);
    }
    t.addall(t.rt[u]);
    t.insert(t.rt[u],val[u],0);
    ans+=t.sum[t.rt[u]];
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin>>n;
    for (int i=1;i<=n;i++) cin>>val[i];
    int y;
    for (int i=2;i<=n;i++){
        cin>>y;
        add(i,y);
        add(y,i);
    }
    dfs(1,0);
    cout<<ans<<endl;
    return 0;
}

可持久化 Trie

一般来说,一种数据结构,只要在修改的过程中,之前的结构没有变化,那么它就能可持久化,即可以以原空间规模访问任意一个历史版本。

于是,有了可持久化 Trie。

我们要实现的功能如下:能够从一个节点开始,访问到插入第 i 个串前的 Trie 树。由于之前插入的子串是不会动的,我们可以利用之前的信息来完成可持久化。

以插入串 cab 为例(图 3)。

最开始 p 处于 1q 处于 8,我们先把 1->2 的边连上,因为 r\neq s[1],所以连了不会影响之后的插入。但是,我们不能把 9 那个点也一起连了,因为这样的话,之后的插入就要在那个点的子树中进行了,破坏了历史版本的信息,所以我们另开一个点 3

之后,像一般 Trie 的插入一样,p 走到 3q 也走到 9,发现 a = s[2],所以不能连到 10,而是另开一个点 5

最后 $p$ 走到 $7$,插入完成。 插入的过程可以概括为:维护这一个历史版本和上一个历史版本走到的位置,除了插入串的节点要新建之外,剩下的都指向原来的版本,直到串插入完毕。 ## 例题5 [P4735 最大异或和](https://www.luogu.com.cn/problem/P4735) 给你一个序列,有两个操作: 1. 加入一个数。 2. 询问左端点在 $[l,r]$ 之间,右端点为序列终点的数异或起来,再异或一个 $x$ 的最大值。 假设没有 $l,r$ 的限制,可以维护一个异或前缀和 $s$,之后就是求 $s[n]\oplus s[i-1]\oplus x$ 的最大值。把 $s[n]\oplus x$ 看作一个整体,就是在 $n$ 个数里面找两个数异或的最大值,这不就是例题2吗? 现在,我们加入一个限制 $r$,于是只能在 $1$~$r-1$ 中选一个 $s[i]$,发现这不就是只插入 $1$~$r-1$ 的 Trie 吗?查询 $r-1$ 的历史版本就能解决。 最后,加入一个限制 $l$,于是只能在 $l-1$~$r-1$ 中选一个 $s[i]$,我们需要对每个节点打上一个标记,代表它的子树中出现时间最晚的数的出现时间,在贪心的基础上只递归出现时间 $\geq l-1$ 的就行了。 ### Code: ```c++ #include<bits/stdc++.h> using namespace std; const int N = 6e5+10,H = 24; int n,m,s[N],tot; struct Trie{ int tr[N*22][2],tot,rt[N],lat[N*24]; void insert(int now,int las,int i,int dep){ if (dep<0){ lat[now] = i; return; } int c = (s[i]>>dep)&1; if (las) tr[now][c^1] = tr[las][c^1]; tr[now][c] = ++tot; insert(tr[now][c],tr[las][c],i,dep-1); lat[now] = max(lat[tr[now][0]],lat[tr[now][1]]); } int query(int now,int lim,int x,int dep){ if (dep<0) return s[lat[now]]^x; int c = (x>>dep)&1; if (lat[tr[now][c^1]]>=lim) return query(tr[now][c^1],lim,x,dep-1); else return query(tr[now][c],lim,x,dep-1); } } t; int main(){ ios::sync_with_stdio(false); cin.tie(0); cin>>n>>m; t.lat[0] = -1; t.rt[tot] = ++t.tot; t.insert(t.rt[tot],0,0,H); int x; for (int i=1;i<=n;i++){ cin>>x; tot++; s[tot] = s[tot-1]^x; t.rt[i] = ++t.tot; t.insert(t.rt[i],t.rt[i-1],tot,H); } char op; int l,r; for (int i=1;i<=m;i++){ cin>>op; if (op == 'A'){ cin>>x; tot++; s[tot] = s[tot-1]^x; t.rt[tot] = ++t.tot; t.insert(t.rt[tot],t.rt[tot-1],tot,H); } else{ cin>>l>>r>>x; cout<<t.query(t.rt[r-1],l-1,s[tot]^x,H)<<"\n"; } } return 0; } ```