一种优化版本树上回溯内存开销的 trick
xzf_200906 · · 算法·理论
先吐槽一下某人机模拟赛,标程空间复杂度错误但由于数据随机可以通过(虽然确实很难卡),被我用该方法爆标。
引入
实现一个长度为
n 的可持久化数组,有m 个版本,在每个版本上会有若干个修改,总修改次数为q 。此时使用可持久化线段树或版本树上朴素回溯均会产生不低于O(q) 的空间复杂度,而q 最高可达O(nm) 级别,在某些情况下会被卡空间。而该算法仅需O(n\log m) 的空间复杂度。算法
原型
该算法无论有多少个操作都会产生
O(nm) 的时间复杂度,故仅适用于n 较小而操作次数极多的场景,如可持久化的根号分治。
由于该问题在最坏情况下会将整个数组修改一遍,故考虑不使用回溯栈回溯,而是在每个节点将当前版本的数组记录下来。但发现当即将递归某个版本的最后一个儿子时,该节点无需记录状态,贪心地,这个儿子显然最好是重儿子,则该算法的步骤如下:
- 处理完当前节点的所有操作,并记录当前数组。
- 遍历所有轻儿子,每次遍历完将数组恢复到该节点记录的状态。
- 删除该节点记录的状态。
- 遍历重儿子。
由于每个点到根节点的路径上至多有
改进型
该算法将上述
当
代码
以下是【模板】可持久化线段树 1(可持久化数组)的 AC 代码,使用上述改进型算法,效率较为优秀。
由于此题中的操作次数较少,可能难以触发重构部分(即代码中的
reBuild())。但可以调小重构阈值以测试其正确性(即代码中的//Here)。另:严重怀疑这题数据是随的,就算将重构阈值调至
n/100也不会在后两个 Subtask 触发重构。#include <bits/stdc++.h> #define LL long long using namespace std; int op[1000005],pos[1000005],val[1000005],ans[1000005],topS[25],cap,n,m; int cur[1000005],tmp[1000005],siz[1000005],son[1000005],top=0; pair<int,int> his[25][1000005]; vector<int> e[1000005]; void pre(int p){ siz[p]=1; for(auto it:e[p]){ pre(it); siz[p]+=siz[it]; if(!son[p]||siz[son[p]]<siz[it]) son[p]=it; } } void reBuild(int p){ for(int i=1;i<=n;i++) tmp[i]=cur[i]; pair<int,int> *ed=his[p]+topS[p],*st=his[p]; while(ed!=st){ tmp[ed->first]=ed->second; ed--; } topS[p]=-1; int *poi=tmp+1; for(int i=1;i<=n;i++){ st->first=*poi; st++; poi++; } } void callBack(int p){ if(topS[p]==-1){ int *poi=cur+1; pair<int,int> *st=his[p]; for(int i=1;i<=n;i++){ *poi=st->first; st++; poi++; } } else{ pair<int,int> *ed=his[p]+topS[p],*st=his[p]; while(ed!=st){ cur[ed->first]=ed->second; ed--; } } topS[p]=0; } void solve(int p,int tp){ if(op[p]==1){ if(topS[tp]!=-1) his[tp][++topS[tp]]={pos[p],cur[pos[p]]}; cur[pos[p]]=val[p]; if(topS[tp]>=cap) reBuild(tp); } else ans[p]=cur[pos[p]]; int now=++top; topS[now]=0; for(auto it:e[p]){ if(it!=son[p]){ solve(it,now); callBack(now); } } top--; if(son[p]) solve(son[p],tp); } int main(){ ios::sync_with_stdio(false); cin.tie(0); cout.tie(0); cin>>n>>m; cap=n;//Here for(int i=1;i<=n;i++) cin>>cur[i]; for(int i=1;i<=m;i++){ int v; cin>>v>>op[i]>>pos[i]; if(op[i]==1) cin>>val[i]; e[v].push_back(i); } pre(0); solve(0,0); for(int i=1;i<=m;i++){ if(op[i]==2) cout<<ans[i]<<'\n'; } return 0; }