题解:P4883 mzf的考验

· · 题解

竟然找到了野生的平衡树题,初一蒟蒻来交一发题解。

平衡树练手好题。

简要题面

给定长度为 n 的序列 a,处理 m 次以下操作:

题目分析

这里选用无旋 Treap 解决此题。

翻转和求和显然是板子,那异或怎么办呢?
发现 d<2^{20},我们可以对于每一个节点,维护长度为 20 的序列 cnt。其中,cnt_i 表示第 i 位为 1 的个数。

如此处理,当我们异或一个数字 d 时,对于第 i 位:

我们就可以用 O(20) 时间求出区间和 S

S=\sum^{19}_{i=0} cnt_i\times 2^i

同时,我们设计两个懒标记,分别处理翻转和异或。

翻转和异或操作互不影响,翻转后,子树的异或标记依然生效

节点信息的存储如下:

struct Treap {
    int l,r;//左右孩子下标
    int x;//节点值
    int lazy1;//异或懒标记
    bool lazy2;//翻转懒标记
    int sum;//子树和
    int bit_cnt[20];//第i位为1的个数。
}tr[MAXN];

维护子树有关代码如下:

// 更新节点信息:子树大小、子树和、每个位上1的个数
void push_up(int u){
    tr[u].cnt=tr[tr[u].l].cnt+1+tr[tr[u].r].cnt;
    tr[u].sum=tr[u].x+tr[tr[u].l].sum+tr[tr[u].r].sum;
    for(int i=0;i<BIT;i++){
        tr[u].bit_cnt[i]=((tr[u].x>>i)&1)+tr[tr[u].l].bit_cnt[i]+tr[tr[u].r].bit_cnt[i];
    }
}
// 维护cnt和sum
void apply_xor(int u,int val){
    if(!u||!val)return;
    tr[u].lazy1^=val;  // 打标记
    for(int i=0;i<BIT;i++){
        if((val>>i)&1){
            tr[u].bit_cnt[i]=tr[u].cnt-tr[u].bit_cnt[i]; // 该位翻转
        }
    }
    int new_sum=0;
    for(int i=0;i<BIT;i++){
        new_sum+=tr[u].bit_cnt[i]<<i; // 按位重新计算和
    }
    tr[u].sum=new_sum;
    tr[u].x^=val;  // 当前节点的值也要异或
}
// 下传标记
void push_down(int u){
    if(tr[u].lazy2){  // 先翻转标记
        swap(tr[u].l,tr[u].r);
        if(tr[u].l)tr[tr[u].l].lazy2^=1;
        if(tr[u].r)tr[tr[u].r].lazy2^=1;
        tr[u].lazy2=0;
    }
    if(tr[u].lazy1){  // 后异或标记
        if(tr[u].l)apply_xor(tr[u].l,tr[u].lazy1);
        if(tr[u].r)apply_xor(tr[u].r,tr[u].lazy1);
        tr[u].lazy1=0;
    }
}

写出这一部分之后,剩下部分显然是板子,见代码:

pair<int,int> split_size(int u,int k){
        if(!u)return {0,0};
        push_down(u);
        if(tr[tr[u].l].cnt>=k){
            auto x=split_size(tr[u].l,k);
            tr[u].l=x.second;
            push_up(u);
            return {x.first,u};
        }else{
            auto x=split_size(tr[u].r,k-tr[tr[u].l].cnt-1);
            tr[u].r=x.first;
            push_up(u);
            return {u,x.second};
        }
    }
    int merge(int l,int r){
        if(!l||!r)return l|r;
        if(tr[l].s<tr[r].s){
            push_down(l);
            tr[l].r=merge(tr[l].r,r);
            push_up(l);
            return l;
        }else{
            push_down(r);
            tr[r].l=merge(l,tr[r].l);
            push_up(r);
            return r;
        }
    }
    void insert(int x){
        tr[++cnt]={0,0,x,rand(),1,0,0,0};
        for(int i=0;i<BIT;i++)tr[cnt].bit_cnt[i]=(x>>i)&1;
        tr[cnt].sum=x;
        root=merge(root,cnt);
    }
    void reverse(int l,int r){
        auto[a,b]=split_size(root,l-1);
        auto[c,d]=split_size(b,r-l+1);
        tr[c].lazy2^=1;
        root=merge(a,merge(c,d));
    }
    void xor_range(int l,int r,int val){
        auto[a,b]=split_size(root,l-1);
        auto[c,d]=split_size(b,r-l+1);
        apply_xor(c,val);
        root=merge(a,merge(c,d));
    }
    int query_sum(int l,int r){
        auto[a,b]=split_size(root,l-1);
        auto[c,d]=split_size(b,r-l+1);
        int res=tr[c].sum;
        root=merge(a,merge(c,d));
        return res;
    }

最终代码

::::info[Code]

#include <iostream>
#include <cstdlib>
#include <algorithm>
using namespace std;
#define int long long
#define MAXN 100010
#define BIT 20

struct FHQ_Treap_size{
    int cnt=0,root=0;
    struct Treap{
        int l,r,x,s,cnt;
        int lazy1;
        bool lazy2;
        int sum;
        int bit_cnt[20];
    }tr[MAXN];

    void push_up(int u){
        tr[u].cnt=tr[tr[u].l].cnt+1+tr[tr[u].r].cnt;
        tr[u].sum=tr[u].x+tr[tr[u].l].sum+tr[tr[u].r].sum;
        for(int i=0;i<BIT;i++){
            tr[u].bit_cnt[i]=((tr[u].x>>i)&1)+tr[tr[u].l].bit_cnt[i]+tr[tr[u].r].bit_cnt[i];
        }
    }
    void apply_xor(int u,int val){
        if(!u||!val)return;
        tr[u].lazy1^=val;
        for(int i=0;i<BIT;i++){
            if((val>>i)&1){
                tr[u].bit_cnt[i]=tr[u].cnt-tr[u].bit_cnt[i];
            }
        }
        int new_sum=0;
        for(int i=0;i<BIT;i++){
            new_sum+=tr[u].bit_cnt[i]<<i;
        }
        tr[u].sum=new_sum;
        tr[u].x^=val;
    }
    void push_down(int u){
        if(tr[u].lazy2){
            swap(tr[u].l,tr[u].r);
            if(tr[u].l)tr[tr[u].l].lazy2^=1;
            if(tr[u].r)tr[tr[u].r].lazy2^=1;
            tr[u].lazy2=0;
        }
        if(tr[u].lazy1){
            if(tr[u].l)apply_xor(tr[u].l,tr[u].lazy1);
            if(tr[u].r)apply_xor(tr[u].r,tr[u].lazy1);
            tr[u].lazy1=0;
        }
    }
    pair<int,int> split_size(int u,int k){
        if(!u)return {0,0};
        push_down(u);
        if(tr[tr[u].l].cnt>=k){
            auto x=split_size(tr[u].l,k);
            tr[u].l=x.second;
            push_up(u);
            return {x.first,u};
        }else{
            auto x=split_size(tr[u].r,k-tr[tr[u].l].cnt-1);
            tr[u].r=x.first;
            push_up(u);
            return {u,x.second};
        }
    }
    int merge(int l,int r){
        if(!l||!r)return l|r;
        if(tr[l].s<tr[r].s){
            push_down(l);
            tr[l].r=merge(tr[l].r,r);
            push_up(l);
            return l;
        }else{
            push_down(r);
            tr[r].l=merge(l,tr[r].l);
            push_up(r);
            return r;
        }
    }
    void insert(int x){
        tr[++cnt]={0,0,x,rand(),1,0,0,0};
        for(int i=0;i<BIT;i++)tr[cnt].bit_cnt[i]=(x>>i)&1;
        tr[cnt].sum=x;
        root=merge(root,cnt);
    }
    void reverse(int l,int r){
        auto[a,b]=split_size(root,l-1);
        auto[c,d]=split_size(b,r-l+1);
        tr[c].lazy2^=1;
        root=merge(a,merge(c,d));
    }
    void xor_range(int l,int r,int val){
        auto[a,b]=split_size(root,l-1);
        auto[c,d]=split_size(b,r-l+1);
        apply_xor(c,val);
        root=merge(a,merge(c,d));
    }
    int query_sum(int l,int r){
        auto[a,b]=split_size(root,l-1);
        auto[c,d]=split_size(b,r-l+1);
        int res=tr[c].sum;
        root=merge(a,merge(c,d));
        return res;
    }
}trs;

signed main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    srand(time(0));

    int n,m;
    cin >> n >> m;
    for(int i=1;i<=n;i++){
        int x;
        cin >> x;
        trs.insert(x);
    }

    while(m--){
        int opt,l,r,d;
        cin >> opt;
        if(opt==1){
            cin >> l >> r;
            trs.reverse(l,r);
        }else if(opt==2){
            cin >> l >> r >> d;
            trs.xor_range(l,r,d);
        }else{
            cin >> l >> r;
            cout << trs.query_sum(l,r) << '\n';
        }
    }
}

::::

时间复杂度

假设一次操作中,splitmerge 访问的节点数为 \log n,且下传异或标记需要 O(B) 时间(对于此题,B=20):

加上建树的 O(n\log n\times B) 总时间复杂度:

O((n+m)\times \log n\times B)

感谢大家观看,如果有事实性错误、疑问或者建议可以直接提在评论区,我会立即处理。

不过这种小题不会有人来看的吧。