[??记录]CF566C Logistical Questions
command_block
·
·
个人记录
题意 : 给出一棵n个节点的无根树,点有点权w[i],边有边权。
定义dis(u,v)为 u,v 之间简单路径上的边权和。
求一个点t使得\sum\limits_{i=1}^nw[i]dis(t,i)^{1.5}最小。 (带权重心)
------------
首先要观察一下性质,由于 $dis(t,i)^{1.5}$ 是凸函数,凸函数的和仍然是凸函数。
所以对于任意一条路径上的$t$ , 代价$\sum\limits_{i=1}^ndis(t,i)^{1.5}$都是凸的。
这就告诉我们,整颗树的代价必然是朝着一个中心严格单调下降的。
我们可以先选取一个点,求出代价,然后比较相邻点的代价,往唯一的代价小的点走(水往低处流)
考虑点分治,每次向更低的联通块找重心,这样可以保证移动不超过$O(\log n)$次。
问题在于可能中心的度数很大,对每个相邻点都计算代价会很慢。
怎么判断某个函数的增减性呢?不难想到求导。
先求出每个子树的导数和,由线性性易求。
具体咋求 : $c\frac{d}{dx}x^{1.5}=c*1.5*x^{0.5}$,这里$x$就是深度,$c$是点权,由于只用于比较,可以把常数省略。
然后查看其他子树的导数和减去某个子树的导数和,如果为负数则表明这里低。
不过,我们在这里强行把不连续函数变成了连续函数,最后一跳可能取在某条边的中间,这时候需要把两个端点都试一试。
当然也有可能恰好在某个点上……
```cpp
#include<algorithm>
#include<cstdio>
#include<vector>
#include<cmath>
#define db double
#define pb push_back
#define MaxN 205000
using namespace std;
inline int read(){
register int X=0;
register char ch=0;
while(ch<48||ch>57)ch=getchar();
while(ch>=48&&ch<=57)X=X*10+(ch^48),ch=getchar();
return X;
}
int n,w[MaxN];
vector<int> g[MaxN],l[MaxN];
int siz[MaxN],t[MaxN],tn,mp[MaxN];
void pfs(int u)
{
siz[t[++tn]=u]=1;
for (int i=0,v;i<g[u].size();i++)
if (!siz[v=g[u][i]]){
pfs(v);
mp[u]=max(mp[u],siz[v]);
siz[u]+=siz[v];
}
}
int getrt(int u)
{
tn=0;pfs(u);int al=siz[u],ret=0;
for (int i=1,v;i<=tn;i++){
v=t[i];
mp[v]=max(mp[v],al-siz[v]);
if (mp[v]<mp[ret])ret=v;
}for (int i=1,v;i<=tn;i++)
mp[t[i]]=siz[t[i]]=0;
return ret;
}
db sum,d[MaxN];
void dfs(int u,int fa,int len)
{
sum+=sqrt(len)*w[u];
for (int i=0,v;i<g[u].size();i++)
if ((v=g[u][i])!=fa)
dfs(v,u,len+l[u][i]);
}
int ans1,ans2;
void solve(int u)
{
siz[u]=1;
db o=0.0;
for (int i=0,v;i<g[u].size();i++){
sum=0;dfs(v=g[u][i],u,l[u][i]);
d[i]=sum;o+=sum;
}
for (int i=0;i<g[u].size();i++)
if (o<d[i]*2.0){
if (siz[u]&&siz[g[u][i]]){
ans1=u;ans2=g[u][i];
return ;
}solve(getrt(g[u][i]));
return ;
}
ans1=ans2=u;
}
void cfs(int u,int fa,int len)
{
sum+=pow(len,1.5)*w[u];
for (int i=0,v;i<g[u].size();i++)
if ((v=g[u][i])!=fa)
cfs(v,u,len+l[u][i]);
}
int main()
{
n=read();mp[0]=n+1;
for (int i=1;i<=n;i++)w[i]=read();
for (int i=1,fr,to,len;i<n;i++){
fr=read();to=read();len=read();
g[fr].pb(to);l[fr].pb(len);
g[to].pb(fr);l[to].pb(len);
}solve(getrt(1));
db sum1=0.0,sum2=0.0;
sum=0;cfs(ans1,0,0);sum1=sum;
sum=0;cfs(ans2,0,0);sum2=sum;
if (sum2<sum1)printf("%d %.10lf",ans2,sum2);
else printf("%d %.10lf",ans1,sum1);
return 0;
}
```