题解 P3806 【【模板】点分治1】
hicc0305
2018-10-31 21:04:09
里脊说点分治可以等后面再学,可我觉得好像理解一下不用多少时间。。还是蛮好理解的。
### 用处
点分治主要用来处理树上的一些求距离为k的点对的操作。比如:树上距离为k的点对是否存在,树上距离小于等于k的点对有多少个。。。
### 解法
首先,点分治肯定要分治。。。
把问题转化为,在当前子树中:
1.过根的点对 2.在根的子树中的点对
对于情况一我们立即处理,而对于情况二我们递归即可。
然后,学过平衡树的各位巨佬们应该都能理解,让树更平衡能更快。怎么样让树更平衡呢?就是让树的重心做根。重心的定义就是以重心为根的所有子树中最大的子树的大小比其他点做根的最大子树大小都要小。
求重心的一遍大法师(是的,一遍就可以了):
```
void getroot(int u,int fa)
{
siz[u]=1;f[u]=0;
for(int i=head[u];i!=-1;i=nxt[i])
{
int v=to[i];
if(v==fa || vis[v]) continue;
getroot(v,u);
siz[u]+=siz[v];
f[u]=max(f[u],siz[v]);
}
f[u]=max(f[u],sum-siz[u]);//以u为根的所有子树大小就是以1为根时所有的子树以及除以u为根的这颗子树的所有点
if(f[u]<f[rt]) rt=u;
}
```
具体代码怎么写哪?我们先找出整棵树的重心,以重心为根大法师,再往下递归,找出子树的重心,再以子树的重心为子树的根往下扫,就是:
```
void solve(int u)
{
Add(u,0,1);vis[u]=1;//对于这道模板题,这里加的是u中所有点选取两个点和u的距离和,不能保证合法,
//也就是选取的两个点在同一颗子树中的话,他们的距离被处理成了u->v1+u->v2,而不是lca->v1+lca->v2,
//所以下面的那个add要把这些不合法的情况减掉,不能理解的话,代码放出来结合一下就能知道了
for(int i=head[u];i!=-1;i=nxt[i])
{
int v=to[i];
if(vis[v]) continue;
Add(v,val[i],-1);
sum=siz[v];rt=0;
getroot(v,u);//找出子树的重心
solve(rt);
}
}
```
那么处理距离,上面一堆注释也有提及了,我们先处理出u的子树中,所有点离u的距离,然后选取两个点,cnt[u->v1+u->v2]++,cnt记录的就是距离为k的点对有多少个。显然,这样子对于两个点在不同的子树中的情况是成立的,另外不合法的我们可以通过减掉子树中的所有情况处理。
那么全部代码:
```cpp
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
int n,m,cnt=0,rt,sum;
int head[10100],nxt[20100],to[20100],val[20100],ans[100100];
int f[100100],siz[10100],dep[100100],tmp[100100],vis[100100];
void addedge(int x,int y,int z)
{
cnt++;
nxt[cnt]=head[x];
head[x]=cnt;
to[cnt]=y;
val[cnt]=z;
}
void getroot(int u,int fa)
{
siz[u]=1;f[u]=0;
for(int i=head[u];i!=-1;i=nxt[i])
{
int v=to[i];
if(v==fa || vis[v]) continue;
getroot(v,u);
siz[u]+=siz[v];
f[u]=max(f[u],siz[v]);
}
f[u]=max(f[u],sum-siz[u]);
if(f[u]<f[rt]) rt=u;
}
void getdep(int u,int fa)
{
tmp[++cnt]=dep[u];//处理子树中离u的距离
for(int i=head[u];i!=-1;i=nxt[i])
{
int v=to[i];
if(v==fa || vis[v]) continue;
dep[v]=dep[u]+val[i],getdep(v,u);
}
}
void Add(int u,int s,int p)
{
dep[u]=s;cnt=0;
getdep(u,0);
for(int i=1;i<=cnt;i++)
for(int j=i+1;j<=cnt;j++)
ans[tmp[i]+tmp[j]]+=p;//对点对距离进行加减
}
void solve(int u)
{
Add(u,0,1);vis[u]=1;
for(int i=head[u];i!=-1;i=nxt[i])
{
int v=to[i];
if(vis[v]) continue;
Add(v,val[i],-1);//子树中的所有点对距离+val[i]就是不合法的情况,减掉即可,最终剩下的仅仅是过当前点的链
sum=siz[v];rt=0;
getroot(v,u);//递归处理子树
solve(rt);
}
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d%d",&n,&m);
for(int i=1;i<n;i++)
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
addedge(x,y,z);
addedge(y,x,z);
}
sum=f[0]=n;getroot(1,0);
solve(rt);
while(m--)
{
int k;
scanf("%d",&k);
if(ans[k]) printf("AYE\n");
else printf("NAY\n");
}
return 0;
}
```