P5092 [USACO04OPEN] Cube Stacking 题解

· · 题解

看到这道题大家几乎都写的带权并查集,这里提供另一种思路。

solution

我们模拟一下会发现,当将一个立方柱放到另一个上面时,这个立方柱上的所有方块下面的方块数会加上下面一个立方柱的高度。

此时,我们会想到进行区间加之类的操作将那个立方柱上的所有方块对应的答案进行更新,但这些方块编号不是连续的,因此我们要将所有方块编号进行映射成另一个编号从而使他们连续。我们可以把输入存下来,先用链表模拟一下整个搭的过程,当搭完毕的时候,我们再将同一条立方柱的方块从上到下映射成连续的编号。

为什么这么做呢?因为一个立方柱无法被拆只可能整体来移,所以,一个立方柱上的方块不会换顺序,并且一个方块只与在最后跟其在同一柱子上的方块可能在中途在同一柱子上。因此,我们就保证了每次搭上去的柱子的编号连续的。

剩下的就简单了,我们可以开一个区间修改单点查询的树状数组来维护每个方块下面的方块数,每次搭上去一个立方柱便将那个立方柱的所有方块下面的方块数用一次区间加更新,查询也直接查就行了。

code

#include<iostream>
#include<cstdio>
#define lowbit(x) ((x)&(-(x)))
using namespace std;
long long tree[30100];
int mp[30100];//映射
int head[30100],tail[30100],nxt[30100],siz[30100]; //head 方块所属柱子的顶部方块 tail 方块所属柱子的底部方块 nxt 方块下面的方块 siz 方块所属的柱子的高度但只保证查顶部顶部方块能得到正确值
int findh(int x){
    return (head[x]==x)?x:(head[x]=findh(head[x]));
}
void joinh(int x,int y){
    x=findh(x),y=findh(y);
    if(x!=y)head[x]=y;
    siz[y]+=siz[x];
    return;
}
int findt(int x){
    return (tail[x]==x)?x:(tail[x]=findt(tail[x]));
}
void joint(int x,int y){
    x=findt(x),y=findt(y);
    if(x!=y)tail[x]=y;
    return;
}
struct node{
    char op;
    int x,y;
}c[100010];
int n=0;
inline void modify(int x,int k){
    for(int i=x;i<=n;i+=lowbit(i)){
        tree[i]+=k;
    }
    return;
}
inline int query(int x){
    long long ans=0;
    for(int i=x;i>0;i-=lowbit(i)){
        ans+=tree[i];
    }
    return ans;
}
int main(){
    cin.tie(0);
    ios::sync_with_stdio(false);
    int m;
    cin>>m;
    for(int i=1;i<=30000;i++){
        head[i]=i,tail[i]=i;
        siz[i]=1;
    }
    for(int i=1;i<=m;i++){
        char op;
        int x,y=0;
        cin>>op>>x;
        if(op=='M'){
            cin>>y;
            nxt[findt(x)]=findh(y);
            joint(x,y);
            joinh(y,x);
        }
        c[i].op=op,c[i].x=x,c[i].y=y;
    }
    int cnt=0;
    for(int i=1;i<=30000;i++){
        if(head[i]==i){
            for(int p=i;p;p=nxt[p]){
                mp[p]=++n;
            }
        }
    }
    for(int i=1;i<=n;i++){
        head[i]=i,tail[i]=i;
        siz[i]=1;
    }
    for(int i=1;i<=m;i++){
        int op=c[i].op,x=c[i].x,y=c[i].y;
        if(op=='M'){
            int t=siz[findh(y)];
            modify(mp[findh(x)],t);
            modify(mp[findt(x)]+1,-t); 
            joint(x,y);
            joinh(y,x);
        }else{
            cout<<query(mp[x])<<"\n";
        }
    }
    cout<<flush;
    return 0;
}