贴代码时请选择正确语言qwq
```cpp
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<map>
#include<time.h>
#define maxn 40005
#define maxm 100005
#define rr register
using namespace std;
// One arc of the adjacency list (each undirected edge is stored as two arcs).
struct edge
{
    int to;   // vertex this arc points to
    int nxt;  // index of the next arc leaving the same vertex; 0 ends the list
} e[maxn<<1];
// One path query (l, r) plus its position in the input, used to restore answer order.
struct node
{
    int l;    // first endpoint of the path
    int r;    // second endpoint of the path
    int num;  // 1-based index of this query in input order
} q[maxm];
// ---- scalar state ----
int n, m;     // vertex count, query count
int unit;     // target block size for the tree partition
int tot;      // first the number of distinct colours, later reused as the block counter
int top;      // height of the blocking stack
int cnt;      // number of arcs stored in e[]
int ans;      // distinct colours among currently active vertices
int l, r;     // endpoints of the previously processed query
int LcA;      // LCA of the current query's endpoints
// ---- tree / lifting storage ----
int head[maxn<<1];   // first arc index per vertex (0 = no arcs)
int f[maxn][32];     // f[x][k] = 2^k-th ancestor of x (0 when above the root)
int dep[maxn];       // depth, root = 0
// ---- per-vertex / per-query arrays ----
// NOTE(review): `data` and `stack` share names with std::data (C++17) / std::stack;
// with `using namespace std;` unqualified uses may be ambiguous, so refer to them
// as ::data / ::stack.
int a[maxn];         // compressed colour of each vertex
int data[maxn];      // scratch copy of colours used for coordinate compression
int belong[maxn];    // block id of each vertex
int stack[maxn];     // blocking stack used by dfs
int vis[maxn];       // 1 if the vertex is currently in the active set
int ANS[maxm];       // answers in input order
int num[maxn];       // per-colour count among active vertices
// Reads one decimal integer from stdin.
// Non-digit characters before the number are skipped; the number is negative only
// if the character immediately before its first digit is '-'.
// Returns 0 if EOF is reached before any digit, instead of looping forever.
inline int read()
{
    int x = 0;
    bool neg = false;
    int c = getchar();  // int, not char: must be able to hold EOF on every platform
    while (c != EOF && (c < '0' || c > '9'))
    {
        neg = (c == '-');
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return neg ? -x : x;
}
// Writes x in decimal to stdout, with no trailing separator.
// The magnitude is computed in unsigned arithmetic so that INT_MIN prints
// correctly; negating it as int would overflow, which is undefined behaviour.
inline void out(int x)
{
    unsigned int mag = static_cast<unsigned int>(x);
    if (x < 0)
    {
        putchar('-');
        mag = 0u - mag;
    }
    char buf[12];  // enough for the 10 digits of UINT_MAX
    int len = 0;
    do
    {
        buf[len++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while (mag != 0);
    while (len > 0) putchar(buf[--len]);
}
// Inserts the undirected edge {u, v} as two arcs, each pushed to the front of
// its source vertex's adjacency list.
inline void add(int u,int v)
{
    // arc u -> v
    ++cnt;
    e[cnt].to = v;
    e[cnt].nxt = head[u];
    head[u] = cnt;
    // arc v -> u
    ++cnt;
    e[cnt].to = u;
    e[cnt].nxt = head[v];
    head[v] = cnt;
}
inline void dfs(int u)
{
int t=top;
for(rr int i=1;(1<<i)<=dep[u];i++)
f[u][i]=f[f[u][i-1]][i-1];
for(rr int i=head[u];i;i=e[i].nxt) if(e[i].to!=f[u][0])
{
f[e[i].to][0]=u,
dep[e[i].to]=dep[u]+1,
dfs(e[i].to);
if(top-t>=unit)
{
++tot;
while(top>t) belong[stack[top--]]=tot;
}
}
stack[++top]=u;
}
inline void block_lca()
{
dfs(1);
while(top) belong[stack[top--]]=tot;
}
inline bool cmp(node a,node b)
{
return belong[a.l]==belong[b.l]?a.r<b.r:a.l<b.l;
}
inline void deal_q(int x)
{
(vis[x]^=1)?(ans+=(++num[a[x]]==1)):(ans-=(!--num[a[x]]));
}
inline void mo(int u,int v)
{
if(dep[u]<dep[v]) swap(u,v);
while(dep[u]>dep[v]) deal_q(u),u=f[u][0];
while(u!=v) deal_q(u),deal_q(v),u=f[u][0],v=f[v][0];
}
inline int lca(int u,int v)
{
if(dep[v]>dep[u]) swap(u,v);
for(rr int i=0;i<=16;++i) if((dep[u]-dep[v])&(1<<i))
u=f[u][i];
if(u==v) return u;
for(rr int i=16;i>=0;--i) if(f[u][i]!=f[v][i])
u=f[u][i],v=f[v][i];
return f[u][0];
}
int main()
{
clock_t t=clock();
n=read(),m=read(),unit=pow(n,0.666);
for(rr int i=1;i<=n;++i) a[i]=data[i]=read();
sort(data+1,data+n+1),
tot=unique(data+1,data+n+1)-data-1;
for(rr int i=1;i<=n;++i) a[i]=lower_bound(data+1,data+tot+1,a[i])-data;
for(rr int i=1,u,v;i<n;++i) u=read(),v=read(),add(u,v);
for(rr int i=1,u,v;i<=m;++i) u=read(),v=read(),q[i]=(node){u,v,i};
sort(q+1,q+m+1,cmp),
block_lca();
for(rr int i=1;i<=m;++i)
mo(l,q[i].l),mo(r,q[i].r),
l=q[i].l,r=q[i].r,
LcA=lca(l,r),
deal_q(LcA),
ANS[q[i].num]=ans,
deal_q(LcA);
for(rr int i=1;i<=m;++i) out(ANS[i]),puts("");
printf("%dms",clock()-t);
return 0;
}
```
by NaCly_Fish @ 2019-04-23 23:34:17
@[NaCly_Fish](/space/show?uid=115864) 抱歉，那时候有点急，就直接把代码放进去了 qwq
by Jy_Amoy @ 2019-04-25 09:06:39
然后现在 A 掉了。昨晚回去仔细想了想，觉得这种做法应该不可能比题解里那些写法快，结果今天重写一遍就过了。后来发现，之前的代码之所以那么慢，是因为查询区间的左右端点存反了，哭哭。
by Jy_Amoy @ 2019-04-25 09:08:18