```cpp
#include<bits/stdc++.h>
#define int long long
using namespace std;
int read(){
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
return x*f;
}
struct node1{int val,id;}b[2000005];
struct node2{int l,r,id;}q[2000005];
struct node3{int nxt,to;}e[2000005];
int w,n,m,bl,qcnt,cntt,num,a[2000005],res,rt=1,ans[2000005],h[2000005],c[2000005],son[2000005],dfn[2000005],dep[2000005],f[2000005][25],cnt[2][2000005];
bool cmp1(node1 x,node1 y){return x.val<y.val;}
bool cmp2(node2 x,node2 y){return ((x.l-1)/bl)^((y.l-1)/bl)?x.l<y.l:x.r<y.r;}
void add_edge(int u,int v){e[cntt].to=v,e[cntt].nxt=h[u],h[u]=cntt++;}
void dfs(int x,int fa){
c[dfn[x]=++num]=x,son[x]=1,dep[x]=1+dep[fa],f[x][0]=fa;
for(int i=1;i<=20;i++)f[x][i]=f[f[x][i-1]][i-1];
for(int i=h[x];i^(-1);i=e[i].nxt){
if(e[i].to==fa)continue;
dfs(e[i].to,x),son[x]+=son[e[i].to];
}
}
int get(int x,int y){
if(dep[x]<dep[y])swap(x,y);
for(int i=20;i>=0;i--)if(dep[f[x][i]]>dep[y])x=f[x][i];
return x;
}
int lca(int x,int y){
if(dep[x]<dep[y])swap(x,y);
for(int i=20;i>=0;i--)if(dep[f[x][i]]>=dep[y])x=f[x][i];
if(x==y)return x;
for(int i=20;i>=0;i--){
if(f[x][i]^f[y][i])x=f[x][i],y=f[y][i];
}
return f[x][0];
}
void add2(int x){res+=cnt[1][a[c[x]]],++cnt[0][a[c[x]]];}
void add1(int x){res+=cnt[0][a[c[x]]],++cnt[1][a[c[x]]];}
void del2(int x){res-=cnt[1][a[c[x]]],--cnt[0][a[c[x]]];}
void del1(int x){res-=cnt[0][a[c[x]]],--cnt[1][a[c[x]]];}
signed main(){
n=read(),m=read();
bl=sqrt(n),memset(h,-1,sizeof h);
for(int i=1;i<=n;i++)b[i].val=read(),b[i].id=i;
sort(b+1,b+n+1,cmp1);
for(int i=1;i<=n;i++)w+=(b[i].val>b[i-1].val),a[b[i].id]=w;
for(int i=1,u,v;i<n;i++){
u=read(),v=read();
add_edge(u,v),add_edge(v,u);
}
dfs(1,0);
for(int i=n+1;i<=(n<<1);i++)c[i]=c[i-n];
for(int i=1,opt,x,y,l1,l2,r1,r2;i<=m;i++){
opt=read();
if(opt==1)rt=x=read(),ans[i]=-1;
else{
x=read(),y=read();
if(x==rt)l1=1,r1=n;
else if(x^lca(x,rt))l1=dfn[x],r1=dfn[x]+son[x]-1;
else{
int qq=get(x,rt);
l1=dfn[qq]+son[qq],r1=n+dfn[qq]-1;
}
if(y==rt)l2=1,r2=n;
else if(y^lca(y,rt))l2=dfn[y],r2=dfn[y]+son[y]-1;
else{
int qq=get(y,rt);
l2=dfn[qq]+son[qq],r2=n+dfn[qq]-1;
}
q[++qcnt].id=i,q[qcnt].l=r1,q[qcnt].r=r2;
if(q[qcnt].l>q[qcnt].r)swap(q[qcnt].l,q[qcnt].r);
q[++qcnt].id=-i,q[qcnt].l=l2-1,q[qcnt].r=r1;
if(q[qcnt].l>q[qcnt].r)swap(q[qcnt].l,q[qcnt].r);
q[++qcnt].id=-i,q[qcnt].l=l1-1,q[qcnt].r=r2;
if(q[qcnt].l>q[qcnt].r)swap(q[qcnt].l,q[qcnt].r);
q[++qcnt].id=i,q[qcnt].l=l1-1,q[qcnt].r=l2-1;
if(q[qcnt].l>q[qcnt].r)swap(q[qcnt].l,q[qcnt].r);
}
}
sort(q+1,q+qcnt+1,cmp2);
for(int i=1,l=0,r=0;i<=qcnt;i++){
while(q[i].l>l)add2(++l);
while(q[i].l<l)del2(l--);
while(q[i].r>r)add1(++r);
while(q[i].r<r)del1(r--);
ans[abs(q[i].id)]+=(q[i].id>0?res:-res);
}
for(int i=1;i<=m;i++)if(ans[i]>-1)cout<<ans[i]<<endl;
return 0;
}
```
by alex_liu @ 2023-04-25 14:00:42
%
by alex__liu @ 2023-04-25 14:03:13
@[alex_liu](/user/373355) 去掉```#define int long long```
by Nityacke @ 2023-04-25 14:13:36
楼上正解,亲测AC<https://www.luogu.com.cn/record/108910941>
by zjx331 @ 2023-04-25 19:54:52
但是这题得开 long long 罢
by seanlsy @ 2023-04-25 20:27:15
@[Ethereal_shadow](/user/568716) 已经关了,感谢您。
此贴结。
by alex_liu @ 2023-04-26 13:05:07