```
#include<bits/stdc++.h>
using namespace std;
#define mid (l+r)/2
#define lc o*2
#define rc o*2+1
#define MAXN 100010
int n,m;
struct node {
int add,lec,ric,sum;
} a[4*MAXN];
int fir[2*MAXN],nex[2*MAXN],to[2*MAXN],tot;
int k1,k2,k3,k4;
char flag;
int nu[MAXN],num[MAXN],son[MAXN],id[MAXN],fa[MAXN],cnt,dep[MAXN],siz[MAXN],top[MAXN];
int read() {
int re=0,f=1;
char ch=getchar();
while (ch<'0' || ch>'9') {ch=getchar();if(ch=='-')f*=-1;}
while (ch>='0' && ch<='9') {
re=re*10+ch-'0';
ch=getchar();
}
return re*f;
}
///////////////////////////////////////////////////////////////////////////
void addedge(int kk1,int kk2) {
to[++tot]=kk2;
nex[tot]=fir[kk1];
fir[kk1]=tot;
to[++tot]=kk1;
nex[tot]=fir[kk2];
fir[kk2]=tot;
}
void pushup(int o) {
a[o].lec=a[lc].lec;
a[o].ric=a[rc].ric;
a[o].sum=a[lc].sum+a[rc].sum;
if(a[lc].ric==a[rc].lec)a[o].sum--;
}
void pushdown(int o) {
if(a[o].add) {
a[lc].add=a[rc].add=a[o].add;
a[lc].lec=a[rc].lec=a[o].add;
a[lc].ric=a[rc].ric=a[o].add;
a[lc].sum=a[rc].sum=1;
a[o].add=0;
}
}
///////////////////////////////////////////
void build(int o,int l,int r) {
if(l==r) {
a[o].sum=1;
a[o].lec=num[l];
a[o].ric=num[l];
return;
}
build(lc,l,mid);
build(rc,mid+1,r);
pushup(o);
}
/////////////////////////////////////////
void update(int o,int l,int r,int x,int y,int k) {
if(x<=l&&r<=y) {
a[o].sum=1;
a[o].lec=k;
a[o].ric=k;
a[o].add=k;
return;
}
pushdown(o);
if(x<=mid)update(lc,l,mid,x,y,k);
if(y>mid)update(rc,mid+1,r,x,y,k);
pushup(o);
}
/////////////////////////
int query(int o,int l,int r,int x,int y) {
if(x<=l&&r<=y)return a[o].sum;
pushdown(o);
int ans=0;
if(x<=mid)ans+=query(lc,l,mid,x,y);
if(y>mid)ans+=query(rc,mid+1,r,x,y);
if(a[lc].ric==a[rc].lec)ans--;
return ans;
}
/////////////////////////////////////////
void dfs1(int u,int f,int d) {
fa[u]=f;
siz[u]=1;
dep[u]=d;
for(int i=fir[u]; i; i=nex[i]) {
int v=to[i];
if(v==f)continue;
dfs1(v,u,d+1);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
/////////////////////////////////////////
void dfs2(int u,int topf) {
top[u]=topf;
id[u]=++cnt;
num[cnt]=nu[u];
if(!son[u])return;
dfs2(son[u],topf);
for(int i=fir[u]; i; i=nex[i]) {
int v=to[i];
if(v==fa[u]||v==son[u])continue;
dfs2(v,v);
}
}
int qpoint(int o,int l,int r,int p) {
if(l==r)return a[o].lec;
pushdown(o);
if(p<=mid) return qpoint(lc,l,mid,p);
else return qpoint(rc,mid+1,r,p);
}
int qRange(int u,int v) {
int ans=0,fc1,fc2;
while(top[u]!=top[v]) {
if(dep[top[u]]<dep[top[v]])swap(u,v);
ans+=query(1,1,n,id[top[u]],id[u]);
fc1=qpoint(1,1,n,id[top[u]]);
fc2=qpoint(1,1,n,id[fa[top[u]]]);
u=fa[top[u]];
if(fc1==fc2)ans--;
}
if(dep[u]>dep[v])swap(u,v);
ans+=query(1,1,n,id[u],id[v]);
if(qpoint(1,1,n,id[u])==qpoint(1,1,n,id[v]))ans--;
return ans;
}
void qUpdate(int u,int v,int k) {
while(top[u]!=top[v]) {
if(dep[top[u]]<dep[top[v]])swap(u,v);
update(1,1,n,id[top[u]],id[u],k);
u=fa[top[u]];
}
if(dep[u]>dep[v])swap(u,v);
update(1,1,n,id[u],id[v],k);
}
int main() {
cin>>n>>m;
for(register int i=1; i<=n; i++)nu[i]=read();
for(register int i=1; i<n; i++) {
k1=read();
k2=read();
addedge(k1,k2);
}
dfs1(1,0,1);
dfs2(1,1);
build(1,1,n);
for(register int i=1; i<=m; i++) {
cin>>flag;
if(flag=='C') {
k1=read();
k2=read();
k3=read();
qUpdate(k1,k2,k3);
} else {
k1=read();
k2=read();
printf("%d\n",qRange(k1,k2));
}
}
}
```
改完后的WA代码
by zyywzw @ 2018-06-28 11:44:43