【牛客】XOR TREE(树链剖分)
90nwyn
2020-11-30 20:42:12
[题目链接](https://ac.nowcoder.com/acm/contest/4090/F)
------------
考虑路径 $\langle u,v\rangle$ 上每一个点的贡献。设路径上点的个数为 $k$，依次为 $v_1,v_2,\dots,v_k$（其中 $v_1=u$，$v_k=v$）。
若$k$为偶数,答案为所有点的权值异或和
若$k$为奇数,答案为所有下标为偶数的点的权值异或和
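由于路径上相邻两点的深度恰好相差 $1$，点 $v_i$ 的深度奇偶性与 $u$ 相同当且仅当 $i$ 为奇数；又 $k=\mathrm{dep}_u+\mathrm{dep}_v-2\,\mathrm{dep}_{\mathrm{lca}(u,v)}+1$，故 $k$ 为偶数当且仅当 $u$ 与 $v$ 的深度奇偶性不同。因此只需按深度奇偶性分别求路径上点权的异或和：
使用树链剖分，线段树的每个结点分别维护深度为奇数的点和深度为偶数的点的权值异或和即可。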
------------
```cpp
#include <bits/stdc++.h>
using namespace std;
// ---- Global state of the HLD + segment tree solution ----
typedef long long ll;   // not referenced by the code below
const int M = 4e5 + 5;  // capacity of every per-vertex / per-edge array

int n, q;   // vertex count, query count
int a[M];   // vertex weights as read from input (1-indexed)

// Adjacency lists as singly linked lists of directed edges:
// head[x] = first edge of x, ver[e] = edge target, nxt[e] = next edge, tot = edges used.
int head[M], ver[M], nxt[M], tot;

// Heavy-light decomposition arrays:
//   f = parent, dep = depth (root gets dep[0] + 1), son = heavy child,
//   siz = subtree size, top = head of the chain, id = position in DFS order,
//   rk = inverse of id (position -> vertex), cnt = last position handed out.
int f[M], dep[M], son[M], siz[M], top[M], id[M], cnt, rk[M];

// Child indices of segment tree node i.
#define ls (i << 1)
#define rs (i << 1 | 1)

// Segment tree node over positions [l, r]; v[p] holds the XOR of the weights
// stored for vertices in that range whose depth parity is p.
struct d
{
    int l, r;
    int v[2];
} tree[M * 4];
// Append the directed edge x -> y by pushing it onto the front of x's list.
void add(int x, int y)
{
    ++tot;
    ver[tot] = y;
    nxt[tot] = head[x];
    head[x] = tot;
}
// First HLD pass over the subtree of x (whose parent is fa): records parent,
// depth and subtree size, and picks the heavy child son[x] as the child with
// the strictly largest subtree seen first. Recursion depth equals tree depth.
void dfs1(int x, int fa)
{
    f[x] = fa;
    dep[x] = dep[fa] + 1;
    siz[x] = 1;
    for (int e = head[x]; e != 0; e = nxt[e])
    {
        const int child = ver[e];
        if (child == fa)
            continue;   // skip the edge back to the parent
        dfs1(child, x);
        siz[x] += siz[child];
        if (siz[child] > siz[son[x]])
            son[x] = child;
    }
}
// Second HLD pass: hands out DFS positions so each heavy chain occupies a
// contiguous run of positions. t is the head of the chain containing x.
void dfs2(int x, int t)
{
    top[x] = t;
    ++cnt;
    id[x] = cnt;
    rk[cnt] = x;            // position -> vertex
    const int heavy = son[x];
    if (heavy != 0)
        dfs2(heavy, t);     // the heavy child continues x's chain
    for (int e = head[x]; e; e = nxt[e])
    {
        const int child = ver[e];
        const bool skip = (child == f[x]) || (child == heavy);
        if (!skip)
            dfs2(child, child);   // every light child heads a new chain
    }
}
// Pull-up: node i's per-parity XOR is the XOR of its two children's values.
void up(int i)
{
    tree[i].v[0] = tree[ls].v[0] ^ tree[rs].v[0];
    tree[i].v[1] = tree[ls].v[1] ^ tree[rs].v[1];
}
// Build node i over positions [l, r]. A leaf at position l stores a[rk[l]]
// in the slot matching that vertex's depth parity; the other slot is left
// untouched.
void build(int i, int l, int r)
{
    tree[i].l = l;
    tree[i].r = r;
    if (l != r)
    {
        const int mid = (l + r) / 2;
        build(ls, l, mid);
        build(rs, mid + 1, r);
        up(i);
        return;
    }
    const int vertex = rk[l];
    tree[i].v[dep[vertex] % 2] = a[vertex];
}
int que(int i,int l,int r,int k)
{
if(tree[i].l==l&&tree[i].r==r)return tree[i].v[k];
int mid=(tree[i].l+tree[i].r)/2;
if(r<=mid)return que(ls,l,r,k);
else if(l>mid)return que(rs,l,r,k);
return que(ls,l,mid,k)^que(rs,mid+1,r,k);
}
// Point assignment: store x for the vertex at position pos (in the slot of its
// depth parity), then recompute the ancestors on the way back up.
void upd(int i, int pos, int x)
{
    if (tree[i].l != tree[i].r)
    {
        const int mid = (tree[i].l + tree[i].r) / 2;
        upd(pos <= mid ? ls : rs, pos, x);
        up(i);
        return;
    }
    tree[i].v[dep[rk[pos]] % 2] = x;
}
// Query on the path between x and y. Climbs heavy chains, accumulating the XOR
// of path vertices separately by depth parity. If the endpoint depths have
// different parity (even number of path vertices) the result is the XOR of all
// path vertices; otherwise it is the XOR of the vertices whose depth parity
// differs from the endpoints' (the even-indexed ones counted from x).
int calc(int x, int y)
{
    const int px = dep[x] % 2;   // endpoint depth parities, fixed before climbing
    const int py = dep[y] % 2;
    int acc[2] = {0, 0};         // acc[p] = XOR of path vertices with depth parity p
    while (top[x] != top[y])
    {
        if (dep[top[x]] < dep[top[y]])
            swap(x, y);
        // Take the segment from x up to its chain head, then jump above the chain.
        const int from = id[top[x]];
        const int to = id[x];
        acc[1] ^= que(1, from, to, 1);
        acc[0] ^= que(1, from, to, 0);
        x = f[top[x]];
    }
    if (dep[x] > dep[y])
        swap(x, y);
    acc[1] ^= que(1, id[x], id[y], 1);
    acc[0] ^= que(1, id[x], id[y], 0);
    if (px != py)
        return acc[0] ^ acc[1];
    return acc[px ^ 1];
}
// Reads n, q, the n weights, the n-1 edges and then q queries from stdin.
// Query "1 u v" sets vertex u's weight to v; any other op prints calc(u, v).
int main()
{
    scanf("%d%d", &n, &q);
    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i);
    for (int e = 1; e < n; ++e)
    {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y);   // undirected edge stored in both directions
        add(y, x);
    }
    dfs1(1, 0);      // root the tree at vertex 1
    dfs2(1, 1);
    build(1, 1, n);
    while (q--)
    {
        int op, u, v;
        scanf("%d%d%d", &op, &u, &v);
        if (op != 1)
        {
            printf("%d\n", calc(u, v));
            continue;
        }
        upd(1, id[u], v);
    }
    return 0;
}
```