「两个小时的思考」高级代码

· · 个人记录

今日,cxl 在模拟赛中,通过两个小时的深入思考,写出了一份代码。

猜猜这份代码所求的 `coef` 数组是什么含义:

// Factorial, inverse-factorial and modular-inverse tables, filled by init().
int fac[M]{};
int ifac[M]{};
int inv[M]{};

// Precompute, for k in [1, n]: fac[k] = k!, ifac[k] = 1/k! and inv[k] = 1/k modulo `mod`
// (assumes mod is prime and larger than n — TODO confirm against the definition of mod).
// fac[0], ifac[0] and inv[1] are always set to 1.
// inv[k] uses the recurrence inv[k] = -(mod / k) * inv[mod % k]; because mod % k < k,
// the entry it needs is already final when k is visited in increasing order.
IL void init(reg int n)
{
    fac[0]=1,ifac[0]=1,inv[1]=1;
    for(reg int k=1;k<=n;++k)
    {
        if(k>1)inv[k]=Mul(mod-mod/k,inv[mod%k]);
        fac[k]=Mul(fac[k-1],k);
        ifac[k]=Mul(ifac[k-1],inv[k]);
    }
}

// Binomial coefficient C(n, m) modulo `mod`, using the tables built by init().
// n == m returns 1 without touching the tables (kept first for compatibility).
// An m outside [0, n] returns 0 instead of indexing ifac at a negative / too-large index.
IL int C(reg int n,reg int m)
{
    if(n==m)return 1;
    if(m<0||m>n)return 0;
    return Mul(fac[n],Mul(ifac[m],ifac[n-m]));
}

int n;                  // number of vertices, read in main
int m;                  // NOTE(review): read in main as read()-1 but not used in the visible code
std::vector<int> G[N];  // undirected adjacency lists indexed by vertex id; add() inserts both directions

// Record the undirected edge (u, v): v goes into u's list, then u into v's list.
IL void add(reg int u,reg int v)
{
    G[u].push_back(v);
    G[v].push_back(u);
}

// Per-vertex DP tables indexed [vertex][i]: f0, f1, f2 hold the x^0, x^1, x^2 parts.
int f0[N][N]{};
int f1[N][N]{};
int f2[N][N]{};
int sz[N]{};    // subtree size of each vertex (set in dfs)

// Binomial product of two coefficient triples (the x^0, x^1, x^2 parts):
//   a[k] = sum_{i+j=k} C(k, i) * b[i] * c[j]   for k = 0, 1, 2,
// and the three results are added into a0, a1, a2 respectively.
// b and c must each hold at least 3 entries (indices 0..2 are read).
// This is called O(n^2) times by the tree DP, so the accumulator lives on the
// stack instead of in a freshly allocated std::vector. The parameters are taken
// by const reference, which still accepts the braced-list arguments used by callers.
IL void slack(reg int &a0,reg int &a1,reg int &a2,const std::vector<int>&b,const std::vector<int>&c)
{
    int a[3]={0,0,0};
    for(reg int i=0;i<3;++i)
        for(reg int j=0;i+j<3;++j)
            Pls(a[i+j],Mul(Mul(b[i],c[j]),C(i+j,i)));
    // Accumulate into the outputs only after a[] is complete.
    Pls(a0,a[0]),Pls(a1,a[1]),Pls(a2,a[2]);
}

int coef[N]{};  // coef[i] accumulates the x^2 contributions from dfs(); main prints coef[1..n]

// Rooted tree DP over the subtree of u. main starts it at vertex 1; fa is u's
// parent (0 for the root call).
// f0/f1/f2[u][i] for i = 1..sz[u] are u's three tables (the x^0, x^1, x^2 parts,
// per their declaration). Before returning, every vertex adds its x^2 table
// (shifted, unless u is the root) into the global coef[].
void dfs(reg int u,reg int fa=0)
{
    // A lone vertex: entry i = 1 is 1 in all three tables, subtree size 1.
    f0[u][1]=f1[u][1]=f2[u][1]=sz[u]=1;
    // Scratch buffers shared by every recursion level. This is safe because each
    // level only touches them after its recursive call for the current child returns.
    static int g0[N],g1[N],g2[N];
    for(reg auto v:G[u])if(v!=fa)
    {
        dfs(v,u);
        // Clear every slot the merge can write: the largest index is sz[u]+sz[v].
        for(reg int i=1;i<=sz[u]+sz[v];++i)g0[i]=g1[i]=g2[i]=0;
        for(reg int i=1,j;i<=sz[u];++i)
        {
            // Pair u's entry i with v's entry j: the binomial product of the two triples lands at i+j
            // (presumably the case where edge u-v is kept — confirm against the problem statement).
            for(j=1;j<=sz[v];++j)slack(g0[i+j],g1[i+j],g2[i+j],{f0[u][i],f1[u][i],f2[u][i]},{f0[v][j],f1[v][j],f2[v][j]});
            // Terms that do not use v's tables: u's triple with weight +1 at i and -1 (mod-1) at i+1.
            slack(g0[i],g1[i],g2[i],{f0[u][i],f1[u][i],f2[u][i]},{1,0,0});
            slack(g0[i+1],g1[i+1],g2[i+1],{f0[u][i],f1[u][i],f2[u][i]},{mod-1,0,0});
        }
        sz[u]+=sz[v];
        // Commit the merged tables back into u.
        for(reg int i=1;i<=sz[u];++i)f0[u][i]=g0[i],f1[u][i]=g1[i],f2[u][i]=g2[i];
    }
    if(u>1)
    {
        // Not the root vertex 1: apply the same +1 at i / -1 at i+1 shift once more
        // (presumably for the edge to the parent — TODO confirm).
        for(reg int i=1;i<=sz[u]+1;++i)g0[i]=g1[i]=g2[i]=0;
        for(reg int i=1;i<=sz[u];++i)
        {
            slack(g0[i],g1[i],g2[i],{f0[u][i],f1[u][i],f2[u][i]},{1,0,0});
            slack(g0[i+1],g1[i+1],g2[i+1],{f0[u][i],f1[u][i],f2[u][i]},{mod-1,0,0});            
        }
    }else for(reg int i=1;i<=sz[u]+1;++i)g2[i]=f2[u][i]; // root: take the x^2 table as-is; g0/g1 are not read below
    // Add the x^2 part into the global answer. Index sz[u]+1 is touched, so N must exceed n+1.
    for(reg int i=1;i<=sz[u]+1;++i)Pls(coef[i],g2[i]);
}

// Read n, m and the n-1 tree edges, run the DP from vertex 1, then print coef[1..n]
// separated by spaces and terminated by a newline.
int main()
{
    n=read();
    m=read()-1;
    init(100000);
    // Exactly n-1 edges; this loop also terminates correctly for n == 0.
    for(reg int i=1;i<n;++i)
    {
        // Read the endpoints in an explicit order: the evaluation order of
        // function arguments, as in add(read(),read()), is unspecified.
        const int u=read();
        const int v=read();
        add(u,v);
    }
    dfs(1);
    for(reg int i=1;i<=n;++i)printf("%d%c",coef[i]," \n"[i==n]);
    return 0;
}
\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \

答:以 $O(n^2)$ 的复杂度,对每个 $i$ 求树上长度为 $i$ 的路径个数。

小 E:机房史实——用 $O(n^2)$ 求树上长度为 $i$ 的路径数,想了两个小时!