[DP记录]AT2268 [AGC008F] Black Radius
command_block
2021-01-13 09:39:46
**题意** : 给出一颗 $n$ 个节点的树 $T$,以及关键点集合 $S$。
定义树上邻域 $f(u,d)=\{\text{距离u不超过d的点集}\}$。
令 $P=\{f(u,d)|u\in S\}$ ,求 $|P|$。即本质不同的关键点邻域个数。
$n\leq 2\times 10^5$ ,时限$\texttt{2s}$。
------------
牛逼题。
- $S=T$
先统计所有 $f(u,d)≠T$ 的贡献,最后加 $1$。
这样,对于同一个 $u$ ,有 $d_1≠d_2\Leftrightarrow f(u,d_1)≠f(u,d_2)$。
若 $P$ 中有若干个 $f(u,d)$ 同构,考虑在 $d$ 最小的 $f$ 处将其统计。(关键思想)
先观察怎样的 $f(u,d_1),f(v,d_2)$ 会相同,如下图 :
![](https://cdn.luogu.com.cn/upload/image_hosting/v7djc2xa.png)
对于同一个关键点 $u$ ,只有若干个 $d$ 是有贡献的。考虑某个 $f(u,d)$ 能贡献的条件。
首先我们要求 $f(u,d)≠T$ ,设 $md[u]$ 为 $u$ 到最远点的距离,则有 $d<md[u]$。
此外,考虑同构。以 $u$ 为根,考虑另一个关键点点 $v$ ,设 $k=dis(u,v)$。
显然,除了 $f(v,d-k)$ 之外,其余的 $f(v,\_)≠f(u,d)$。
若 $f(v,d-k)=f(u,d)$ ,则 $\Leftrightarrow$ 两者都能覆盖除了 $v$ 上方的所有部分。
![](https://cdn.luogu.com.cn/upload/image_hosting/nx2ck4s6.png)
(如图,蓝色点为选择的 $v$ ,红色部分为 $f(v,d-k)$ 需要覆盖的部分)
设 $td[u][v]$ 为以 $u$ 为根时 $v$ 子树外的点到 $v$ 的最远距离。
则有 $td[u][v]\leq d-k$ 即 $td[u][v]+dis(u,v)\leq d$。排除掉这些 $d$ 即可。
这也说明了,对于某个 $u$ ,能贡献的 $d$ 是一个前缀。
现在,要对每个关键点 $u$ 求出 $\min\limits_{v≠u}td[u][v]+dis(u,v)$。
不难发现,这里的 $v$ 可以只考虑相邻节点。则变为 $\min\limits_{u\leftrightarrow v}td[u][v]+1$
而所有 $td$ 可以使用 $\text{up and down}$ 树形 $\rm DP$ 求出。
- $S\neq T$
在 $f(u,d)$ 中,若 $u$ 为关键点,$d$ 的下界必然是 $0$ ,但对于非关键点则是其他数。
考虑对于非关键点 $u$ ,使得 $f(u,d)$ 有贡献的最小的 $d$。
以 $u$ 为根,考虑某个关键点 $u$。若能覆盖 $u$ 所在的整个分支,就能“夺过” $u$ 的邻域。
不难发现,当 $d$ 更大时,上述关系仍然可以成立。
于是,下界即为 : 覆盖某个含关键点的分支所需的最小 $d$。
也可以使用 $\text{up and down}$ 树形 $\rm DP$ 求出。。
知道了各个 $d$ 的上下界,容易得到答案。
```cpp
#include<algorithm>
#include<cstdio>
#include<vector>
#define pb push_back
#define MaxN 200500
using namespace std;
const int INF=1000000000;
vector<int> g[MaxN];
char fl[MaxN];
int len[MaxN],fa[MaxN],siz[MaxN];
void dfs1(int u,int _f)
{
fa[u]=_f;
siz[u]=fl[u];
for (int i=0,v;i<g[u].size();i++)
if ((v=g[u][i])!=_f){
dfs1(v,u);
len[u]=max(len[u],len[v]+1);
siz[u]+=siz[v];
}
}
int flen[MaxN];
void dfs2(int u,int l)
{
flen[u]=l;
int p1=-1,x1=-1,x2=-1;
for (int i=0;i<g[u].size();i++)
if (g[u][i]!=fa[u])
if (len[g[u][i]]>x1)
{x2=x1;x1=len[g[u][p1=i]];}
else x2=max(x2,len[g[u][i]]);
for (int i=0;i<g[u].size();i++)
if (g[u][i]!=fa[u])
dfs2(g[u][i],max((i==p1 ? x2 : x1)+2,l+1));
}
int n;
int main()
{
scanf("%d",&n);
for (int i=1,u,v;i<n;i++){
scanf("%d%d",&u,&v);
g[u].pb(v);g[v].pb(u);
}scanf("%s",fl+1);
for (int i=1;i<=n;i++)fl[i]-='0';
dfs1(1,0);dfs2(1,0);
long long ans=0;
for (int u=1;u<=n;u++){
int x1=0,x2=0;
for (int i=0;i<g[u].size();i++)
if (g[u][i]!=fa[u])
if (len[g[u][i]]+1>x1)
{x2=x1;x1=len[g[u][i]]+1;}
else x2=max(x2,len[g[u][i]]+1);
if (flen[u]){
if (flen[u]>x1)
{x2=x1;x1=flen[u];}
else x2=max(x2,flen[u]);
}
int tr=max(0,min(x2+1,max(len[u],flen[u])-1)),tl=n+1;
if (fl[u])tl=0;
for (int i=0;i<g[u].size();i++)
if (g[u][i]!=fa[u]&&siz[g[u][i]])
tl=min(tl,len[g[u][i]]+1);
if (fa[u]&&siz[1]-siz[u]>0)
tl=min(tl,flen[u]);
ans+=(tl<=tr ? tr-tl+1 : 0);
}printf("%lld",ans+1);
return 0;
}
```