题解 P3806 【【模板】点分治1】
淀粉质灰常好吃,请放心食用
点分治:处理树上路径的神器,举个栗子(真好吃):
给定一棵树和一个整数 kk ,求树上边数等于 kk 的路径有多少条
此题是比较水的点分治题了,蒟蒻也只会这个辣,于是乎就来水一发笔记。
对于此题来说,首先O(n^3)的暴力肯定GG,于是我们想到,对于以rt为根的子树,一条路径只有两种情况:1.经过rt;2.不经过rt(废话),就是路径被子树所包含。
但是细想:其实对于情况2,我们可以把问题扔给rt的子树,让它变成在以rt'为根的子树中经过rt',也就直接把情况2转化成了情况1。
如图,Root为根的子树中存在答案(蓝色实边路径),可以看成以 Root2为根的两棵子树存在答案,所以只用处理情况2就行了,可以用分治的方法,这应该是点分治的基本原理——来自dalao守望(洛谷日报)
- 找重心
这个不多解释,一个树形dp,和dsu on tree相同,以重心为根可以最大幅度的减小深度,降低时间复杂度。
code:
void get_root(int u,int fa)
{
dp[u]=0;size[u]=1;//初始化
for(int i=head[u];i;i=g[i].next)
{
int v=g[i].to;
if(v==fa||vis[v])continue;
get_root(v,u);
dp[u]=max(dp[u],size[v]);//
size[u]+=size[v];
}
dp[u]=max(dp[u],S-size[u]);//注意别把父节点作为子节点的情况漏掉了
//S为以rt为根的树的size
if(dp[u]<dp[rt])rt=u;//更新rt,(dp[rt]=inf)
}
- 点分治:从重心开始,往子节点递归,每个点调用一次solve函数,达到分治的目的。
code:
void solve(int u)
{
vis[u]=judge[0]=1;//初始化
dfs(u);//后面会讲到
for(int i=head[u];i;i=g[i].next)
{
int v=g[i].to;
if(vis[v])continue;
S=size[v];dp[rt=0]=inf;//更新S,重新找子树的重心
get_root(v,0);
solve(rt);
}
}
- 找路径长度
dist表示点u到rt的路径长度,代码很好写
code:
void get_dis(int u,int fa)
{
rem[++rem[0]]=dist[u];//rem后面会讲
for(int i=head[u];i;i=g[i].next)
{
int v=g[i].to,d=g[i].dis;
if(v==fa||vis[v])continue;
dist[v]=dist[u]+d;//更新v的dist
get_dis(v,u);
}
}
- dfs(核心代码)判断以经过rt的路径是否有长度为k的
rem数组统计搜索到了路径长度,rem[0]记录的是搜到的路径个数。
judge是桶,判断是否存在长度为i的路径若存在judge[i]=1
对于每个dfs到的点u,先以u为rt找到子树所有点的dist,并用rem记录下来。和query(记录题干中k的数组)进行判断,结果记录到test中。
具体看code:
void dfs(int u)
{
int p=0;
for(int i=head[u];i;i=g[i].next)
{
int v=g[i].to,d=g[i].dis;
if(vis[v])continue;
rem[0]=0;dist[v]=d;//先初始化
get_dis(v,u);//找子树所有点的dist
for(int j=rem[0];j;j--)
for(int k=1;k<=m;k++)
if(query[k]>=rem[j])//因为要找的路径必须经过rt
test[k]|=judge[query[k]-rem[j]];//如果发现k减去当前子树的某一路径长度值
//有另一子树中某一长度相加正好等于k,因为judge
//不是在当前子树被赋值的,而是记录了之前所搜到
//的子树的信息,所以并不会回到当前子树,一定会
//跨过rt节点
for(int j=rem[0];j;j--)
q[++p]=rem[j],judge[rem[j]]=1;//保证judge存储了之前搜索子树的全部长度值
}
//memset(judge,0,sizeof(judge));
for(int i=1;i<=p;i++)
judge[q[i]]=0;//搜完了一个rt,所有信息都要清空,注意memset会TLEQAQ
}
好了,各组件介绍完毕
然后我们就可以非常快乐地调试了……
最后放上全部code完美✿✿ヽ(°▽°)ノ✿
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#define go(i,a,b) for(int i=a;i<=b;i++)
using namespace std;
const int inf=0x7fffffff;
const int N=10010;
struct node
{
int next,to,dis;
}g[N*2];
int head[N],cnt;
int n,m,S;
int dp[N],size[N],rem[N],dist[N];
int rt;
int query[110],test[110],q[N];
bool judge[10000010],vis[N];
inline int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
inline void addedge(int u,int v,int dis)
{
g[++cnt].next=head[u];
g[cnt].to=v;
g[cnt].dis=dis;
head[u]=cnt;
}
void get_root(int u,int fa)
{
dp[u]=0;size[u]=1;//初始化
for(int i=head[u];i;i=g[i].next)
{
int v=g[i].to;
if(v==fa||vis[v])continue;
get_root(v,u);
dp[u]=max(dp[u],size[v]);//
size[u]+=size[v];
}
dp[u]=max(dp[u],S-size[u]);//注意别把父节点作为子节点的情况漏掉了
//S为以rt为根的树的size
if(dp[u]<dp[rt])rt=u;//更新rt,(dp[rt]=inf)
}
void get_dis(int u,int fa)
{
rem[++rem[0]]=dist[u];//rem后面会讲
for(int i=head[u];i;i=g[i].next)
{
int v=g[i].to,d=g[i].dis;
if(v==fa||vis[v])continue;
dist[v]=dist[u]+d;//更新v的dist
get_dis(v,u);
}
}
void dfs(int u)
{
int p=0;
for(int i=head[u];i;i=g[i].next)
{
int v=g[i].to,d=g[i].dis;
if(vis[v])continue;
rem[0]=0;dist[v]=d;//先初始化
get_dis(v,u);//找子树所有点的dist
for(int j=rem[0];j;j--)
for(int k=1;k<=m;k++)
if(query[k]>=rem[j])//因为要找的路径必须经过rt
test[k]|=judge[query[k]-rem[j]];//如果发现k减去当前子树的某一路径长度值
//有另一子树中某一长度相加正好等于k,因为judge
//不是在当前子树被赋值的,而是记录了之前所搜到
//的子树的信息,所以并不会回到当前子树,一定会
//跨过rt节点
for(int j=rem[0];j;j--)
q[++p]=rem[j],judge[rem[j]]=1;//保证judge存储了之前搜索子树的全部长度值
}
//memset(judge,0,sizeof(judge));
for(int i=1;i<=p;i++)
judge[q[i]]=0;//搜完了一个rt,所有信息都要清空,注意memset会TLEQAQ
}
void solve(int u)
{
vis[u]=judge[0]=1;//初始化
dfs(u);//后面会讲到
for(int i=head[u];i;i=g[i].next)
{
int v=g[i].to;
if(vis[v])continue;
S=size[v];dp[rt=0]=inf;//更新S,重新找子树的重心
get_root(v,0);
solve(rt);
}
}
int main()
{
n=read(),m=read();
S=n;
for(int i=1;i<n;i++)
{
int x=read(),y=read(),z=read();
addedge(x,y,z);addedge(y,x,z);
}
dp[rt]=inf;
get_root(1,0);
for(int i=1;i<=m;i++)
query[i]=read();
solve(rt);
for(int i=1;i<=m;i++)
if(test[i])printf("AYE\n");
else printf("NAY\n");
}