鼠树 题解

· · 个人记录

首先一眼毒瘤数据结构
这个题大概支持这么几个操作:
1.单点询问点权
2.小修改,把一个点管辖点加上一个权值
3.询问子树内所有黑点的所有管辖点点权之和
4.大修改,一个子树内所有点都来个2
5.把黑点变白
6.把白点变黑
树上问题基本上肯定至少要dfs序线段树,关键是具体怎么实现
逐个操作分析:
由于涉及到颜色变化,而且每个点的管辖点分布所在区间不连续,所以如果直接维护的(把点权实在地加到每个点上)话,存在比较大的困难,一种做法是线段树合并+线段树分裂,太弱不会,于是仔细研究了一夜题解并调了一天,嘎
一个关键的操作是,对于一个白点,把他的权值累加到他的管辖点上,而不是加到他自己上,这样对于1操作直接查询他的管辖点,对于3操作用权值乘他管辖的点数就行了,对于2操作直接单点修改,这样我们使操作成功简化了
首先我们需要实现一个东西,可以快速求每个点的归属点,这个最暴力的方法就是一直跳父亲,实际上可以一个log实现,具体是进行树剖,维护每条重链,对于每个重链开一个set,把每个黑点插进他topset内,维护的时候首先往上跳top,然后在对应的set里二分,复杂度每次一个log,一共一个log
对每个黑点维护一下它的权值w和它管辖的节点个数c,具体:
5操作:这个操作一般在最前面,如果你把一个点x变了颜色,那么它会接管下面一部分白点,这些点本来被x的归属点管辖,所以x的归属点所管辖的点数会减少,步骤:

$2.$求出$x$应该管辖的点,具体就是用$x$子树中的所有点减去子树中$c$的和即为新的$c $4.$用新的$c$,$w$更新$x$,使其变成黑点 4操作:区间修改$w$,注意到只有黑点有意义,但是最终我们要$w\times c$,所以白点一定没有多余的贡献,直接子树区间修改 2操作:单点修改权值,没啥好说的 1操作:找到他的归属点,返回这个点的权值 3操作:查询子树内$\sum c\times w

6操作:最恶心,先删集合,然后你要把这个点的信息都被他的归属点接管,但他的权值还要保留,可以维护一个差值,另开一颗数记录,然后做一个容斥搞回去,具体:

1.$把$x$从对应集合删除,求出他的归属点$y$以及$w_x-w_y $3.$进行一次$4 \ x \ w_y-w_x$把多余的减去 $4.$最后把$c_y+=c_x$,并把$x$信息清空,被接管 于是要实现三个数据结构: 一个树剖+$set$,支持快速跳链找归属 一个简单线段树(dfs序),支持区间加和区间查询,也能用树状数组,但区修区查似乎不太好写 一个高级线段树,支持单点改$c$,单点改$w$,区间查$\sum c$,单点查$w$,区间加$w$,区间查$\sum c\times w

细节主要在找归属

inline int getfrom(int x)
{
   while(s[top[x]].empty()||d[*s[top[x]].rbegin()]>d[x])x=fa[top[x]];
   return *s[top[x]].lower_bound(x);
} 

注意每次动态更新x,和新的x比较,这样错的:

inline int getfrom(int x)
{
   int p=d[x];
   while(s[top[x]].empty()||d[*s[top[x]].rbegin()]>=p)x=fa[top[x]];
   return *s[top[x]].lower_bound(x);
}  

因为一个性质是:
一个点到根的路径是由若干轻链和若干重链的一部分组成的
每次跳的轻链下面一段可能已经不合法了,所以只留上面的
300行完整版

#include <bits/stdc++.h>
using namespace std;
#define int unsigned int
inline int read()
{
   int x=0;char ch=getchar();
   while(ch<'0'||ch>'9')ch=getchar();
   while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
   return x;
}
const int N=300050;
struct node{
   int from,to,next;
}a[N];
int head[N],mm=1;
inline void add(int x,int y)
{
   a[mm].from=x;a[mm].to=y;
   a[mm].next=head[x];head[x]=mm++;
}
int d[N],size[N],fa[N],son[N];
int top[N],id[N],rk[N],tot;
void dfs1(int x)
{
   for(int i=head[x];i;i=a[i].next)
   {
      int y=a[i].to;
      d[y]=d[x]+1;
      fa[y]=x;dfs1(y);
      size[x]+=size[y];
      if(size[y]>size[son[x]])
         son[x]=y;
   }
   size[x]++;
}
void dfs2(int x,int h)
{
   if(!x)return;
   top[x]=h;
   id[x]=++tot;
   rk[tot]=x;
   dfs2(son[x],h);
   for(int i=head[x];i;i=a[i].next)
   {
      int y=a[i].to;
      if(y==son[x])continue;
      dfs2(y,y);
   }
}
struct cmp{
  bool operator ()(const int &a, const int &b)
    {return d[a]>d[b];}
};
set <int,cmp> s[N];
inline int getfrom(int x)
{
   while(s[top[x]].empty()||d[*s[top[x]].rbegin()]>d[x])x=fa[top[x]];
   return *s[top[x]].lower_bound(x);
}
struct SBtree{
   int l,r,sum,lazy;
}trr[4*N];
void gabuild(int id,int l,int r)
{
   trr[id].l=l;trr[id].r=r;
   if(l==r)return;
   int mid=(l+r)>>1;
   gabuild(id*2,l,mid);
   gabuild(id*2+1,mid+1,r);
}
inline void galuo(int id)
{
   if(trr[id].r!=trr[id].l&&trr[id].lazy)
   {
      trr[id*2].lazy+=trr[id].lazy;
      trr[id*2+1].lazy+=trr[id].lazy;
      trr[id*2].sum+=trr[id].lazy*(trr[id*2].r-trr[id*2].l+1);
      trr[id*2+1].sum+=trr[id].lazy*(trr[id*2+1].r-trr[id*2+1].l+1);
      trr[id].lazy=0;
   }
}
void gaadd(int id,int l,int r,int v)
{
   if(l<=trr[id].l&&r>=trr[id].r)
   {
      trr[id].sum+=v*(trr[id].r-trr[id].l+1);
      trr[id].lazy+=v;return;
   }
   galuo(id);int mid=(trr[id].l+trr[id].r)>>1;
   if(r<=mid)gaadd(id*2,l,r,v);
   else if(l>mid)gaadd(id*2+1,l,r,v);
   else gaadd(id*2,l,mid,v),gaadd(id*2+1,mid+1,r,v);
   trr[id].sum=trr[id*2].sum+trr[id*2+1].sum;
}
int gagetsum(int id,int l,int r)
{  
   if(l<=trr[id].l&&r>=trr[id].r)return trr[id].sum;
   galuo(id);int mid=(trr[id].l+trr[id].r)>>1;
   if(r<=mid)return gagetsum(id*2,l,r);
   if(l>mid)return gagetsum(id*2+1,l,r);
   return gagetsum(id*2,l,mid)+gagetsum(id*2+1,mid+1,r);
}
struct NBtree{
   int l,r,c,w,sum,lazy;
}tr[4*N];
inline void qi(int id)
{  
   tr[id].c=tr[id*2].c+tr[id*2+1].c;
   tr[id].w=tr[id*2].w+tr[id*2+1].w;
   tr[id].sum=tr[id*2].sum+tr[id*2+1].sum;
}
inline void luo(int id)
{
   if(tr[id].l!=tr[id].r&&tr[id].lazy)
   {  
      tr[id*2].w+=tr[id].lazy*(tr[id*2].r-tr[id*2].l+1);
      tr[id*2+1].w+=tr[id].lazy*(tr[id*2+1].r-tr[id*2+1].l+1);
      tr[id*2].lazy+=tr[id].lazy;tr[id*2+1].lazy+=tr[id].lazy;
      tr[id*2].sum+=tr[id].lazy*tr[id*2].c;
      tr[id*2+1].sum+=tr[id].lazy*tr[id*2+1].c;
      tr[id].lazy=0;
   }
}
void build(int id,int l,int r)
{
   tr[id].l=l;tr[id].r=r;
   if(l==r)return;
   int mid=(tr[id].l+tr[id].r)>>1;
   build(id*2,l,mid);build(id*2+1,mid+1,r);
}
void changec(int id,int p,int v)
{
   if(tr[id].l==tr[id].r)
   {
      tr[id].c=v;tr[id].sum=tr[id].c*tr[id].w;
      return;
   }
   luo(id);int mid=(tr[id].l+tr[id].r)>>1;
   if(p<=mid)changec(id*2,p,v);
   else changec(id*2+1,p,v);
   qi(id);
}
void addc(int id,int p,int v)
{
   if(tr[id].l==tr[id].r)
   {
      tr[id].c+=v;tr[id].sum=tr[id].c*tr[id].w;
      return;
   }
   luo(id);int mid=(tr[id].l+tr[id].r)>>1;
   if(p<=mid)addc(id*2,p,v);
   else addc(id*2+1,p,v);
   qi(id);
}
void changew(int id,int p,int v)
{
   if(tr[id].l==tr[id].r)
   {
      tr[id].w=v;tr[id].sum=tr[id].c*tr[id].w;
      return;  
   }
   luo(id);int mid=(tr[id].l+tr[id].r)>>1;
   if(p<=mid)changew(id*2,p,v);
   else changew(id*2+1,p,v);
   qi(id);
}
void addw(int id,int p,int v)
{
   if(tr[id].l==tr[id].r)
   {
      tr[id].w+=v;tr[id].sum=tr[id].c*tr[id].w;
      return;  
   }
   luo(id);int mid=(tr[id].l+tr[id].r)>>1;
   if(p<=mid)addw(id*2,p,v);
   else addw(id*2+1,p,v);
   qi(id);
}
int getsumc(int id,int l,int r)
{
   if(l<=tr[id].l&&r>=tr[id].r)return tr[id].c;
   luo(id);int mid=(tr[id].l+tr[id].r)>>1;
   if(r<=mid)return getsumc(id*2,l,r);
   if(l>mid)return getsumc(id*2+1,l,r);
   return getsumc(id*2,l,mid)+getsumc(id*2+1,mid+1,r);
}
int getw(int id,int p)
{
   if(tr[id].l==tr[id].r)return tr[id].w;
   luo(id);int mid=(tr[id].l+tr[id].r)>>1;
   if(p<=mid)return getw(id*2,p);
   else return getw(id*2+1,p);
}
int getsum(int id,int l,int r)
{  
   if(l<=tr[id].l&&r>=tr[id].r)return tr[id].sum;
   luo(id);int mid=(tr[id].l+tr[id].r)>>1;
   if(r<=mid)return getsum(id*2,l,r);
   if(l>mid)return getsum(id*2+1,l,r);
   return getsum(id*2,l,mid)+getsum(id*2+1,mid+1,r);
}
void ganadd(int id,int l,int r,int v)
{  
   if(l<=tr[id].l&&r>=tr[id].r)
   {
      tr[id].w+=v*(tr[id].r-tr[id].l+1);
      tr[id].sum+=v*tr[id].c;
      tr[id].lazy+=v;return;
   }
   luo(id);int mid=(tr[id].l+tr[id].r)>>1;
   if(r<=mid)ganadd(id*2,l,r,v);
   else if(l>mid)ganadd(id*2+1,l,r,v);
   else ganadd(id*2,l,mid,v),ganadd(id*2+1,mid+1,r,v);
   qi(id);
} 
inline void gan1(int x)
{
   int gui=getfrom(x);
   printf("%u\n",getw(1,id[gui])+gagetsum(1,id[x],id[x]));
}
inline void gan2(int x,int v)
{
   addw(1,id[x],v);
}
inline void gan3(int x)
{
   int ans=getsum(1,id[x],id[x]+size[x]-1)+gagetsum(1,id[x],id[x]+size[x]-1);
   int p=size[x]-getsumc(1,id[x],id[x]+size[x]-1),gui=getfrom(x);
   ans+=p*getw(1,id[gui]);printf("%u\n",ans);
}
inline void gan4(int x,int v)
{
   ganadd(1,id[x],id[x]+size[x]-1,v);
}
inline void gan5(int x)
{
   int gui=getfrom(x);s[top[x]].insert(x);
   int p=size[x]-getsumc(1,id[x],id[x]+size[x]-1),ww=getw(1,id[gui]);
   changec(1,id[x],p);changew(1,id[x],ww);addc(1,id[gui],-p);
}
inline void gan6(int x)
{
   assert(s[top[x]].find(x)!=s[top[x]].end());
   s[top[x]].erase(x);int gui=getfrom(x);
   int ww=getw(1,id[x])-getw(1,id[gui]),p=getsumc(1,id[x],id[x]);
   changec(1,id[x],0);changew(1,id[x],0);
   gaadd(1,id[x],id[x]+size[x]-1,ww);
   gan4(x,-ww);addc(1,id[gui],p);
}
signed main()
{
   int n=read(),m=read();
   for(int i=2;i<=n;i++)
   {
      int x=read();
      add(x,i);
   }
   d[1]=1;dfs1(1);
   dfs2(1,1);s[1].insert(1);
   gabuild(1,1,tot);
   build(1,1,tot);addc(1,1,n);
   for(int i=1;i<=m;i++)
   {
      int op=read();
      if(op==1)
      {
         int x=read();
         gan1(x);
      }
      if(op==2)
      {
         int x=read(),v=read();
         gan2(x,v);
      }
      if(op==3)
      {
         int x=read();
         gan3(x);
      }
      if(op==4)
      {
         int x=read(),v=read();
         gan4(x,v);
      }
      if(op==5)
      {
         int x=read();
         gan5(x);
      }
      if(op==6)
      {
         int x=read();
         gan6(x);
      }
   }
   return 0;
}

确实锻炼码力