[??记录]CF566C Logistical Questions

· · 个人记录

题意 : 给出一棵n个节点的无根树,点有点权w[i],边有边权。

定义dis(u,v)u,v 之间简单路径上的边权和。

求一个点t使得\sum\limits_{i=1}^nw[i]dis(t,i)^{1.5}最小。 (带权重心)

------------ 首先要观察一下性质,由于 $dis(t,i)^{1.5}$ 是凸函数,凸函数的和仍然是凸函数。 所以对于任意一条路径上的$t$ , 代价$\sum\limits_{i=1}^ndis(t,i)^{1.5}$都是凸的。 这就告诉我们,整颗树的代价必然是朝着一个中心严格单调下降的。 我们可以先选取一个点,求出代价,然后比较相邻点的代价,往唯一的代价小的点走(水往低处流) 考虑点分治,每次向更低的联通块找重心,这样可以保证移动不超过$O(\log n)$次。 问题在于可能中心的度数很大,对每个相邻点都计算代价会很慢。 怎么判断某个函数的增减性呢?不难想到求导。 先求出每个子树的导数和,由线性性易求。 具体咋求 : $c\frac{d}{dx}x^{1.5}=c*1.5*x^{0.5}$,这里$x$就是深度,$c$是点权,由于只用于比较,可以把常数省略。 然后查看其他子树的导数和减去某个子树的导数和,如果为负数则表明这里低。 不过,我们在这里强行把不连续函数变成了连续函数,最后一跳可能取在某条边的中间,这时候需要把两个端点都试一试。 当然也有可能恰好在某个点上…… ```cpp #include<algorithm> #include<cstdio> #include<vector> #include<cmath> #define db double #define pb push_back #define MaxN 205000 using namespace std; inline int read(){ register int X=0; register char ch=0; while(ch<48||ch>57)ch=getchar(); while(ch>=48&&ch<=57)X=X*10+(ch^48),ch=getchar(); return X; } int n,w[MaxN]; vector<int> g[MaxN],l[MaxN]; int siz[MaxN],t[MaxN],tn,mp[MaxN]; void pfs(int u) { siz[t[++tn]=u]=1; for (int i=0,v;i<g[u].size();i++) if (!siz[v=g[u][i]]){ pfs(v); mp[u]=max(mp[u],siz[v]); siz[u]+=siz[v]; } } int getrt(int u) { tn=0;pfs(u);int al=siz[u],ret=0; for (int i=1,v;i<=tn;i++){ v=t[i]; mp[v]=max(mp[v],al-siz[v]); if (mp[v]<mp[ret])ret=v; }for (int i=1,v;i<=tn;i++) mp[t[i]]=siz[t[i]]=0; return ret; } db sum,d[MaxN]; void dfs(int u,int fa,int len) { sum+=sqrt(len)*w[u]; for (int i=0,v;i<g[u].size();i++) if ((v=g[u][i])!=fa) dfs(v,u,len+l[u][i]); } int ans1,ans2; void solve(int u) { siz[u]=1; db o=0.0; for (int i=0,v;i<g[u].size();i++){ sum=0;dfs(v=g[u][i],u,l[u][i]); d[i]=sum;o+=sum; } for (int i=0;i<g[u].size();i++) if (o<d[i]*2.0){ if (siz[u]&&siz[g[u][i]]){ ans1=u;ans2=g[u][i]; return ; }solve(getrt(g[u][i])); return ; } ans1=ans2=u; } void cfs(int u,int fa,int len) { sum+=pow(len,1.5)*w[u]; for (int i=0,v;i<g[u].size();i++) if ((v=g[u][i])!=fa) cfs(v,u,len+l[u][i]); } int main() { n=read();mp[0]=n+1; for (int i=1;i<=n;i++)w[i]=read(); for (int i=1,fr,to,len;i<n;i++){ fr=read();to=read();len=read(); g[fr].pb(to);l[fr].pb(len); g[to].pb(fr);l[to].pb(len); }solve(getrt(1)); db sum1=0.0,sum2=0.0; sum=0;cfs(ans1,0,0);sum1=sum; sum=0;cfs(ans2,0,0);sum2=sum; if (sum2<sum1)printf("%d %.10lf",ans2,sum2); else printf("%d %.10lf",ans1,sum1); return 0; } ```