[DS记录]CFgym101234D Forest Game

· · 个人记录

题目Link

题意 : 在树上随机点分治,问期望复杂度。对10^9+7取模。

------------ 对于如此复杂的问题我们肯定要进行转化,我们建立**点分树**。 每个点的贡献是自己被分治到的次数,也即点分树上的**深度**。 根据期望的线性性,我们分别算出每个点的期望深度,加在一起就是答案。 **深度**又可以分解成 : 每个点作为其祖先的概率和。 原树 $(x,y)$ 路径在点分树中,**有且只有**一个点是 $x,y$ 的公共父亲,那就是`LCA`。更浅的祖先均不属于路径 $(x,y)$。 对于点对 $(x,y)$ ,点 $x$ 在点分树上是 $y$ 的祖先的充要条件是 : 原树上 $(x,y)$ 路径中,在点分树上深度最浅的是 $x$. 充分性易于理解,必要性可以考虑`LCA`的祖先必然不被路径$x,y$包含,最浅的就是`LCA`本身。 否则,$x,y$要么倒置,要么在不同的子树内。 定义$dis(x,y)$为$x,y$之间的**点数**。 $x$在路径$(x,y)$上最浅的点,其概率是$\large\frac{1}{dis(x,y)}$. 这看起来非常简洁,有点不可思议。 这题并没有询问本质不同的点分治方案,而是以随机的排列指导分治。 其实我们对于一条链,我们可以取任意一个点做他们的祖先,后续的方案数都是相同的。 那么答案就是$\sum\limits_{i=1}^n\sum\limits_{j=1}^n\frac{1}{dis(x,y)}

这是分式求和,难以批量维护,只好记录分母为状态。

我们分别统计长度为k的路径条数c[k],答案就是\sum\limits_{i=1}^n\frac{c[k]}{i}

c[k]可以用点分治。

对于某个分治中心,记录到根距离为m的点数F[m]

那么子树的路径拼合的贡献就是H[k]=\sum\limits_{i+j==k}F[i]F[j]

此外可能把来自同一棵子树的路径拼了起来,所以还要减去每个子树内的答案。方法类似。

这题模数并不是NTT模数,只好使用FFT。

注意这题卷积得到的最大数可达n^2=10^{10},而且FFT自带一个n就是n^3=10^{15},应该没问题。

具体实现时,点分治数边数然后+1得到点数更方便。

注意(x,x)是合法的,每个点必然是自己的祖先。

写数组版会快,但是注意清空。中间强制转换int挂了几次……

复杂度O(nlog^2n)

#include<algorithm>
#include<cstring>
#include<cstdio>
#include<vector>
#include<cmath>
#define ll long long
#define mod 1000000007
#define MaxN 100500
using namespace std;
const double Pi=acos(-1);
ll powM(ll a,int t=mod-2)
{
  ll ans=1;
  while(t){
    if (t&1)ans=ans*a%mod;
    a=a*a%mod;
    t>>=1;
  }return ans;
}
struct CP
{
  CP (double xx=0,double yy=0){x=xx,y=yy;}
  double x,y;
  CP operator + (CP const &B) const
  {return CP(x+B.x,y+B.y);}
  CP operator - (CP const &B) const
  {return CP(x-B.x,y-B.y);}
  CP operator * (CP const &B) const
  {return CP(x*B.x-y*B.y,x*B.y+y*B.x);}
}S[270000];
int tr[270000];
void FFT(CP *f,int n,bool flag)
{
  for (int i=0;i<n;i++)
    if (i<tr[i])swap(f[i],f[tr[i]]);
  for(int p=2;p<=n;p<<=1){
    int len=p>>1;
    CP tG(cos(2*Pi/p),sin(2*Pi/p));
    if(!flag)tG.y*=-1;
    for(int k=0;k<n;k+=p){
      CP buf(1,0);
      for(int l=k;l<k+len;l++){
        CP tt=buf*f[len+l];
        f[len+l]=f[l]-tt;
        f[l]=f[l]+tt;
        buf=buf*tG;
      }
    }
  }
}
void sqr(ll *F,int m)
{
  for (int i=0;i<=m;i++)S[i].y=S[i].x=F[i];
  int n=1;for(m<<=1;n<=m+1;n<<=1);
  for(int i=0;i<n;i++)
    tr[i]=(tr[i>>1]>>1)|((i&1)?n>>1:0);
  FFT(S,n,1);
  for(int i=0;i<n;i++)S[i]=S[i]*S[i];
  FFT(S,n,0);
  for(int i=0;i<=m;i++)
    F[i]=(ll)(S[i].y/n/2+0.49);
  memset(S,0,sizeof(CP)*(n+5));
}
vector<int> g[MaxN];
int tp[MaxN],tn,ms[MaxN],siz[MaxN];
bool vis[MaxN];
void pfs(int u,int fa)
{
  tp[++tn]=u;
  siz[u]=1;ms[u]=0;
  for (int i=0,v;i<g[u].size();i++)
    if ((v=g[u][i])!=fa&&!vis[v]){
      pfs(v,u);
      siz[u]+=siz[v];
      ms[u]=max(ms[u],siz[v]);
    }
}
int getrt(int u)
{
  tn=0;pfs(u,0);
  int rt=0;
  for (int i=1;i<=tn;i++){
    ms[tp[i]]=max(ms[tp[i]],tn-siz[tp[i]]);
    if (ms[tp[i]]<ms[rt])rt=tp[i];
  }return rt;
}
ll t1[MaxN],t2[MaxN];int lim;
void dfs(int u,int fa,int len)
{
  lim=max(lim,len);
  t2[len]++;
  for (int i=0,v;i<g[u].size();i++)
    if ((v=g[u][i])!=fa&&!vis[v])
      dfs(v,u,len+1);
}
ll c[MaxN];
void clac(int u)
{
  int l2=0;t1[0]++;
  for (int i=0;i<g[u].size();i++)
    if (!vis[g[u][i]]){
      lim=0;
      dfs(g[u][i],u,1);
      l2=max(l2,lim);
      for (int i=0;i<=lim;i++)
        t1[i]+=t2[i];
      sqr(t2,lim);
      for (int i=0;i<=lim+lim;i++)
        c[i]-=t2[i];
      memset(t2,0,sizeof(ll)*(lim*2+5));
    }
  sqr(t1,l2);
  for (int i=0;i<=l2+l2;i++)
    c[i]+=t1[i];
  memset(t1,0,sizeof(ll)*(l2*2+5)); 

}
void solve(int u)
{
  clac(u);vis[u]=1;
  for (int i=0,v;i<g[u].size();i++)
    if (!vis[v=g[u][i]])
      solve(getrt(v));
}
int n;
int main()
{
  scanf("%d",&n);
  for (int i=1,fr,to;i<n;i++){
    scanf("%d%d",&fr,&to);
    g[fr].push_back(to);
    g[to].push_back(fr);
  }ms[0]=n;solve(getrt(1));
  ll ans=0;
  for (int i=0;i<n;i++)
    ans=(ans+c[i]%mod*powM(i+1))%mod;
  for (int i=1;i<=n;i++)ans=ans*i%mod;
  printf("%I64d",ans);
}