题解:AT_abc461_e [ABC461E] E-liter

· · 题解

怎么写这么长是不是我想复杂了。

::::info[推导]{open}

考虑如何刻画一个格子的状态。

显然对于每个格子 (i, j),其颜色由对其进行的最后一次操作决定。

row_i 表示第 i 行最后一次被涂黑的操作编号。\ 设 col_j 表示第 j 列最后一次被涂白的操作编号。

因此有:格子 (i, j) 为黑色当且仅当 row_i> col_j

所以把查询的式子写出来:

ans = \sum_{i=1}^{N} \sum_{j=1}^{N} [row_i > col_j]

考虑固定 i\sum_{j=1}^{N} [row_i > col_j]col 中小于 row_i 的元素个数。

low(i) = \sum_{col_j<i} 1

所以需要维护 col_jrow_i,快速查询任意的 low\sum low({row_i})

::::

::::info[数据结构维护]{open} 用树状数组维护 col。\ 先将值 v 映射到下标 v+1 存储,让下标 xcol < x 的个数,然后 low(x) 直接用区间和维护即可。

用权值线段树维护 row。\ 线段树节点 [l, r]row_i 落在 [l, r] 内的数量和这些行的 low(row_i) 之和。

::::

::::info[维护操作]{open}

  1. 涂黑第 R 行:
    设当前操作编号为 t,记 old = row_Rnew = t。\ 该行原本贡献为 low(old),变为 low(new)。\ 因此在线段树中将值 old 的位置的行数减 1,总和减 low(old)。\ 同理在线段树中将值 new 的位置的行数加 1,总和加 low(new)。\ 更新 row_R = new
  2. 涂白第 C 列:
    设当前操作编号为 t,计 old = col_Cnew = t。 在树状数组中删去 old,加上 new。\ 更新 col_C = new。\ 对于任意 k \in (old, new]low(k) 减少 1。\ 在线段树上:对值域 [old+1, new] 区间减 1

每次操作后,\sum low(row_i) 即为线段树根节点的 sum,直接输出即可。

::::

::::success[代码]

时间复杂度 O(n \log n)

#include <bits/stdc++.h>
using namespace std;

#define int long long
#define left q*2
#define right q*2+1
#define zqj tre[q].l
#define yqj tre[q].r

inline int read()
{
    int x=0,f=1;char ch=getchar();
    while (ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
    while (ch>='0'&&ch<='9'){x=x*10+ch-48;ch=getchar();}
    return x*f;
}
void write(int x)
{
    if(x<0)putchar('-'),x=-x;
    if(x<10)putchar(x+'0');
    else write(x/10),putchar(x%10+'0');
}
const int MAXQ=300005;
struct xds{
    int l;
    int r;
    int cnt;
    int sum;
    int add;
    void oin(int q,int p,int t,int c){
        l=q;
        r=p;
        sum=t;
        cnt=c;
        add=0;
    }
}tre[MAXQ*4];
int n,q;
int ti1[MAXQ+5];
int ti2[MAXQ+5];
int bit[MAXQ+5];

int lowbit(int n1){return n1&-n1;}

void bitadd(int idx,int k){
    for(;idx<=q+1;idx+=lowbit(idx)) bit[idx]+=k;
}
int bitsum(int idx){
    int s=0;
    for(;idx>0;idx-=lowbit(idx)) s+=bit[idx];
    return s;
}
void pushup(int q){
    tre[q].cnt=tre[left].cnt+tre[right].cnt;
    tre[q].sum=tre[left].sum+tre[right].sum;
}

void pushdown(int q){
    if(tre[q].add!=0){
        tre[left].sum+=tre[left].cnt*tre[q].add;
        tre[left].add+=tre[q].add;
        tre[right].sum+=tre[right].cnt*tre[q].add;
        tre[right].add+=tre[q].add;
        tre[q].add=0;
    }
}

void build(int q,int l,int r){
    tre[q].oin(l,r,0,0);
    if(l==r) return;
    int mid=(l+r)/2;
    build(left,l,mid);
    build(right,mid+1,r);
    pushup(q);
}

void update(int q,int pos,int cnt,int sum){
    if(zqj==yqj){
        tre[q].cnt+=cnt;
        tre[q].sum+=sum;
        return;
    }
    pushdown(q);
    int mid=(zqj+yqj)/2;
    if(pos<=mid) update(left,pos,cnt,sum);
    else update(right,pos,cnt,sum);
    pushup(q);
}

void change(int q,int ml,int mr,int k){
    if(ml>yqj||mr<zqj) return;
    if(ml<=zqj&&mr>=yqj){
        tre[q].sum+=tre[q].cnt*k;
        tre[q].add+=k;
        return;
    }
    pushdown(q);
    int mid=(zqj+yqj)/2;
    if(ml<=mid) change(left,ml,mr,k);
    if(mr>mid) change(right,ml,mr,k);
    pushup(q);
}

signed main(){
    n=read();
    q=read();
    build(1,0,q);
    for(int i=1;i<=n;i++){
        ti1[i]=0;
        update(1,0,1,0);
    }
    for(int j=1;j<=n;j++){
        ti2[j]=0;
        bitadd(1,1);
    }
    for(int t=1;t<=q;t++){
        int op,x;
        op=read();
        x=read();
        if(op==1){
            int la=ti1[x];
            int now=t;
            if(la==now){
                write(tre[1].sum);
                putchar('\n');
                continue;
            }
            int las=bitsum(la);
            int mqs=bitsum(now);
            update(1,la,-1,-las);
            update(1,now,1,mqs);
            ti1[x]=now;
        }else{
            int la=ti2[x];
            int now=t;
            if(la==now){
                write(tre[1].sum);
                putchar('\n');
                continue;
            }
            bitadd(la+1,-1);
            bitadd(now+1,1);
            ti2[x]=now;
            if(la+1<=now) change(1,la+1,now,-1);

        }
        write(tre[1].sum);
        putchar('\n');
    }
    return 0;
}

::::

通过记录