[DP记录]AT2268 [AGC008F] Black Radius

· · 个人记录

题意 : 给出一颗 n 个节点的树 T,以及关键点集合 S

定义树上邻域 f(u,d)=\{\text{距离u不超过d的点集}\}

P=\{f(u,d)|u\in S\} ,求 |P|。即本质不同的关键点邻域个数。

------------ 牛逼题。 - $S=T

先统计所有 f(u,d)≠T 的贡献,最后加 1

这样,对于同一个 u ,有 d_1≠d_2\Leftrightarrow f(u,d_1)≠f(u,d_2)

P 中有若干个 f(u,d) 同构,考虑在 d 最小的 f 处将其统计。(关键思想)

先观察怎样的 f(u,d_1),f(v,d_2) 会相同,如下图 :

对于同一个关键点 u ,只有若干个 d 是有贡献的。考虑某个 f(u,d) 能贡献的条件。

首先我们要求 f(u,d)≠T ,设 md[u]u 到最远点的距离,则有 d<md[u]

此外,考虑同构。以 u 为根,考虑另一个关键点点 v ,设 k=dis(u,v)

显然,除了 f(v,d-k) 之外,其余的 f(v,\_)≠f(u,d)

f(v,d-k)=f(u,d) ,则 \Leftrightarrow 两者都能覆盖除了 v 上方的所有部分。

(如图,蓝色点为选择的 v ,红色部分为 f(v,d-k) 需要覆盖的部分)

td[u][v] 为以 u 为根时 v 子树外的点到 v 的最远距离。

则有 td[u][v]\leq d-ktd[u][v]+dis(u,v)\leq d。排除掉这些 d 即可。

这也说明了,对于某个 u ,能贡献的 d 是一个前缀。

现在,要对每个关键点 u 求出 \min\limits_{v≠u}td[u][v]+dis(u,v)

不难发现,这里的 v 可以只考虑相邻节点。则变为 \min\limits_{u\leftrightarrow v}td[u][v]+1

而所有 td 可以使用 \text{up and down} 树形 \rm DP 求出。

f(u,d) 中,若 u 为关键点,d 的下界必然是 0 ,但对于非关键点则是其他数。

考虑对于非关键点 u ,使得 f(u,d) 有贡献的最小的 d

u 为根,考虑某个关键点 u。若能覆盖 u 所在的整个分支,就能“夺过” u 的邻域。

不难发现,当 d 更大时,上述关系仍然可以成立。

于是,下界即为 : 覆盖某个含关键点的分支所需的最小 d

也可以使用 \text{up and down} 树形 \rm DP 求出。。

知道了各个 d 的上下界,容易得到答案。

#include<algorithm>
#include<cstdio>
#include<vector>
#define pb push_back
#define MaxN 200500
using namespace std;
const int INF=1000000000;
vector<int> g[MaxN];
char fl[MaxN];
int len[MaxN],fa[MaxN],siz[MaxN];
void dfs1(int u,int _f)
{
  fa[u]=_f;
  siz[u]=fl[u];
  for (int i=0,v;i<g[u].size();i++)
    if ((v=g[u][i])!=_f){
      dfs1(v,u);
      len[u]=max(len[u],len[v]+1);
      siz[u]+=siz[v];
    }
}
int flen[MaxN];
void dfs2(int u,int l)
{
  flen[u]=l;
  int p1=-1,x1=-1,x2=-1;
  for (int i=0;i<g[u].size();i++)
    if (g[u][i]!=fa[u])
      if (len[g[u][i]]>x1)
        {x2=x1;x1=len[g[u][p1=i]];}
      else x2=max(x2,len[g[u][i]]);
  for (int i=0;i<g[u].size();i++)
    if (g[u][i]!=fa[u])
      dfs2(g[u][i],max((i==p1 ? x2 : x1)+2,l+1));
}
int n;
int main()
{
  scanf("%d",&n);
  for (int i=1,u,v;i<n;i++){
    scanf("%d%d",&u,&v);
    g[u].pb(v);g[v].pb(u);
  }scanf("%s",fl+1);
  for (int i=1;i<=n;i++)fl[i]-='0';
  dfs1(1,0);dfs2(1,0);
  long long ans=0;
  for (int u=1;u<=n;u++){
    int x1=0,x2=0;
    for (int i=0;i<g[u].size();i++)
      if (g[u][i]!=fa[u])
        if (len[g[u][i]]+1>x1)
          {x2=x1;x1=len[g[u][i]]+1;}
        else x2=max(x2,len[g[u][i]]+1);
    if (flen[u]){
      if (flen[u]>x1)
        {x2=x1;x1=flen[u];}
      else x2=max(x2,flen[u]);
    }
    int tr=max(0,min(x2+1,max(len[u],flen[u])-1)),tl=n+1;
    if (fl[u])tl=0;
    for (int i=0;i<g[u].size();i++)
      if (g[u][i]!=fa[u]&&siz[g[u][i]])
        tl=min(tl,len[g[u][i]]+1);
    if (fa[u]&&siz[1]-siz[u]>0)
      tl=min(tl,flen[u]);
    ans+=(tl<=tr ? tr-tl+1 : 0);
  }printf("%lld",ans+1);
  return 0;
}