P14521 【MX-S11-T2】加减乘除

· · 题解

平衡树实在是太好用了。

对于每个询问,需要遍历整棵树,但是这显然超时,于是考虑把这些询问离线一起处理。

我们把这些询问看做一个集合,每次到一个新的节点,集合的整体大小就会加或者减当前节点 u 的权值 w[u],此时对于 u 的一个子节点 v,只有值在 [l_v,r_v] 之间的集合元素可以到达 v,剩下的集合元素我们从集合中删去,留在当前节点,当结束完对 v 的遍历后再重新加入集合。统计答案可以通过对集合内的元素打标记解决。

这个容易使用平衡树维护,对于整体大小的加减,我们可以额外维护一个变量表示集合内元素整体增加或减少了多少,整体复杂度 O(q\log q)

#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int N=1e6+10;
struct node{
    int ls,rs,sz,rd,x;
    int id,as,aslz;
}; 
node tr[N];
struct Node{
    int v,l,r;
};
vector<Node> g[N];
int w[N];
int rt,cnt;
inline int nt(int x,int id){
    tr[++cnt]={0,0,1,rand(),x,id,0,0};
    return cnt;
}
inline void up(int u){
    tr[u].sz=tr[tr[u].ls].sz+tr[tr[u].rs].sz+1;;
}
inline void down(int u){
    if(tr[u].aslz){
        if(tr[u].ls){
            tr[tr[u].ls].aslz+=tr[u].aslz;
            tr[tr[u].ls].as+=tr[u].aslz;
        }
        if(tr[u].rs){
            tr[tr[u].rs].aslz+=tr[u].aslz;
            tr[tr[u].rs].as+=tr[u].aslz;
        }
    }
    tr[u].aslz=0;
}
inline void spt(int u,int &x,int &y,int k){
    if(!u){
        x=y=0;
        return;
    }
    down(u);
    if(tr[u].x<=k){
        x=u;
        spt(tr[u].rs,tr[u].rs,y,k);
    }
    else{
        y=u;
        spt(tr[u].ls,x,tr[u].ls,k);
    }
    up(u);
}
inline int mg(int u,int v){
    if(!u||!v){
        return u+v;
    }
    down(u),down(v);
    if(tr[u].rd<tr[v].rd){
        tr[u].rs=mg(tr[u].rs,v);
        up(u);
        return u;
    }
    else{
        tr[v].ls=mg(u,tr[v].ls);
        up(v);
        return v;
    }
}
inline void ins(int x,int id){
    int l,r;
    spt(rt,l,r,x-1);
    rt=mg(l,mg(nt(x,id),r));
}
inline void dfs(int u,int nt,int prs){
    prs+=w[u];
    tr[nt].as++;
    tr[nt].aslz++;
    for(int i=0;i<g[u].size();i++){
        int v=g[u][i].v;
        int l=g[u][i].l;
        int r=g[u][i].r;
        int L,md,R;
        spt(nt,L,md,l-1-prs);
        spt(md,md,R,r-prs);
        dfs(v,md,prs);
        nt=mg(L,mg(md,R));
    }
}
int ans[N];
inline void dfst(int u){
    down(u);
    if(tr[u].ls){
        dfst(tr[u].ls);
    }
    ans[tr[u].id]=tr[u].as;
    if(tr[u].rs){
        dfst(tr[u].rs);
    }
}
signed main(){
    ios::sync_with_stdio(false);
    cin.tie(0);cout.tie(0);
    int n,q;cin>>n>>q;
    for(int i=2;i<=n;i++){
        int p,l,r;cin>>p>>l>>r;
        g[p].push_back({i,l,r});
    }
    for(int i=1;i<=n;i++){
        char x;cin>>x;
        if(x=='+'){
            cin>>w[i];
        }
        else{
            cin>>w[i],w[i]=-w[i];
        }
    }
    for(int i=1;i<=q;i++){
        int x;cin>>x;ins(x,i);
    }
    dfs(1,rt,0);
    dfst(rt);
    for(int i=1;i<=q;i++){
        cout<<ans[i]<<endl;
    }
    return 0;
}