P5588 小猪佩奇爬树
Liujy_bc
·
·
题解
题目大意
给定一棵树,树中的每个节点都有一个颜色,对于 1~n 中的每一种颜色,输出这棵树中包含所有这种颜色的节点的链的个数。
对于每种颜色,我们可以看到有以下三种情况:
$2.$这种颜色的节点只有一个。
$3.$这种颜色的节点大于一个。
接下来我们分类讨论。
对于第一种情况,相当于答案对于这条链没有任何要求,因为这是一棵树所以任选两个点作为链的端点,这样的链存在且存在唯一的一条。所以我们最后的答案就是 $C_n^2$。
对于第二种情况,相当于我们要求这棵树上经过当前节点的链的个数。我们考虑怎样的链一定会经过当前节点。我们可以把当前节点当做整棵树的根,显然只有当链的端点分别属于他的两棵子树或者是他本身时才满足条件,我们把他的每棵子树的节点的个数两两相乘,最后再加上他所有子树的节点数的和。
这个东西我们可以在一遍 dfs 里面维护出每一个节点的答案。我们用 $sz_i$ 表示以节点1为根时,以节点 $i$ 为根的子树大小,那么如果以 $i$ 为根的话,他的其他子树都已经处理好了,还没有处理的那一棵的子树大小其实就是 $n-sz_i$ 。具体可以看后面代码实现。
下面来看第三种情况。对于一种颜色有多个点,首先我们需要判断这些点是否在一条链上,如果不在说明不存在一条链满足可以覆盖这种颜色的点。如何判断这些点是否在一条链上呢? 我们可以看一下这道题 [CF1702G2 Passable Paths (hard version)](https://www.luogu.com.cn/problem/CF1702G2)。具体思路就是先找到这些点中深度最深的点设这个点为 $x$,然后再找到同颜色的点中与这个点距离最远的点设为 $y$,如果这些点在同一条链上那么 $x$ 到 $y$ 的距离就是可能的链的最小距离。接下来我们再枚举其余节点到 $x$ 以及到 $y$ 的距离,如果这两个距离接起来不等于 $x$ 到 $y$ 的距离,说明这些点不在一条链上,可以直接输出0。距离可以直接用 LCA 求解。
对于在同一条链上的点又可以分为两种情况。
$1.$ $y$ 是 $x$ 的祖先。
此时的链的两个端点必然是一个在以 $x$ 为根的子树中个数就是前面处理出来的 $sz_x$ ,另一个在除去以在 $x$ 和 $y$ 之间除 $y$ 以外深度最浅的点为根的子树中的点之中。答案就是这两个数相乘。来看下面的图片。

我们现在处理的是的绿色的节点,$y$ 是3,$x$ 是6,在 $x$ 和 $y$ 之间除 $y$ 以外深度最浅的点是5,答案就是 $sz_6\times(n-sz_5)$ 。然后这个点可以利用处理 LCA 时用的倍增数组处理。
$2.$ $y$ 不是 $x$ 的祖先。
这个就比较简单啦,答案就是 $sz_x\times sz_y$。可以看下面这张图理解一下。

我们要处理蓝色的点,$x$ 是5,$y$ 是2,答案就是 $sz_x\times sz_y$。
下面是代码。
```cpp
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<vector>
using namespace std;
int n,m,q,tot,sz[1000005],a[1000005],d[1000005],c[1000005],fa[1000005][20];
int head[1000005],cnt;vector<int> v[1000005];long long ans[1000005];
struct node{
int to,next;
}e[2000005];
void add(int u,int v){e[++cnt]=((node){v,head[u]});head[u]=cnt;}
void dfs(int u){
for(int j=1;j<20;j++)
fa[u][j]=fa[fa[u][j-1]][j-1];
sz[u]=1;
for(int i=head[u];i;i=e[i].next){
int v=e[i].to;
if(v==fa[u][0])continue ;
d[v]=d[u]+1,fa[v][0]=u;
dfs(v);
ans[u]+=1ll*sz[u]*sz[v];
sz[u]+=sz[v];
}
ans[u]+=sz[u]*(n-sz[u]);
}//预处理求LCA的倍增数组以及sz数组
int up(int x,int d){
for(int i=0;d;i++){
if(d%2==1)x=fa[x][i];
d/=2;
}
return x;
}
int lca(int u,int v){
if(d[u]<d[v])swap(u,v);
u=up(u,d[u]-d[v]);
if(u==v)return u;
for(int i=19;i>=0;i--){
if((1<<i)<=d[u]&&fa[u][i]!=fa[v][i])
u=fa[u][i],v=fa[v][i];
}
return fa[u][0];
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d",&c[i]);
v[c[i]].push_back(i);
}
for(int i=1,x,y;i<n;i++){
scanf("%d%d",&x,&y);
add(x,y),add(y,x);
}
dfs(1);
for(int k=1;k<=n;k++){
if(v[k].size()==0){//当前颜色没有节点
printf("%lld\n",1ll*n*(n-1)/2);
continue ;
}
if(v[k].size()==1){//只有一个节点
printf("%lld\n",ans[v[k][0]]);
continue ;
}
int x=0,y=0,z,di=0;m=0;
for(int i=0;i<v[k].size();i++){//找深度最大的节点
a[++m]=v[k][i];
if(d[a[m]]>di)x=a[m],di=d[a[m]];
}
di=0;
for(int i=1;i<=m;i++){//找距离相距最大的点
if(a[i]==x)continue ;
z=lca(x,a[i]);
int dis=d[a[i]]-d[z]+d[x]-d[z];
if(dis>di)y=a[i],di=dis;
}
int flag=true;
for(int i=1;i<=m;i++){//判断是否可以被一条链覆盖
if(a[i]==x||a[i]==y)continue ;
int tmp=lca(x,a[i]);
int tmp1=lca(y,a[i]);
int dis1=d[a[i]]-d[tmp]+d[x]-d[tmp];
int dis2=d[a[i]]-d[tmp1]+d[y]-d[tmp1];
if(dis1+dis2!=di){
flag=false;
break ;
}
}
if(flag){//在一条链上
z=lca(x,y);
if(z==y){
int b=x;
int dep=d[b]-d[y]-1;
for(int i=19;i>=0;i--)//倍增调到深度最浅的点
if(dep&(1<<i))b=fa[b][i];
ans[y]=(n-sz[b]);
}else ans[y]=sz[y];
ans[x]=sz[x];
printf("%lld\n",1ll*ans[y]*ans[x]);
}else printf("0\n");//不在一条链上
}
return 0;
}
```