P3629 题解

· · 题解

暴力出奇迹(

蒟蒻因为确实想不出把边权改成 -1 的神奇操作,只能整出下面这个离谱解法。

这个方法整体上看大概是个分类讨论。

K 为 1 时

这个很显然,只能连一条路那就只能连直径。

因为直径是树中最长的链,连它可使得距离的减少量最大。

K 为 2 时

事情开始变得复杂了。

考虑一下第二条路到底会连在哪里。

由感觉可知 第二条路一定是在加了第一条路后形成的环上再搞出一个环。

为了使得路最短,可以猜到这第二条路的两个端点一定得有个在直径上挂着的支链的尾部。

经过咱深刻思考,发现一共有五种情况:

(蓝色是建的第一条路,红色是建的第二条)

情况 1

这种情况下最后的距离就是 2\times(n-1)-\text{直径}+1-\text{支链长度}+1

而且很显然这个支链长度越长越好,找出来最长的那一条就是这种情况的最优解。

情况 2

这种情况下最后的距离也很好求,就是 2*\times(n-1)-\text{直径}+1-\text{红色边构成的环的长度}+2

情况 3

假设新建的第二条边连接的两条支链对应在树上的节点分别为 L,R

$dis$ 为直径上每个点到直径的起点的距离。 那么最后的距离就是 $2*(n-1)-\text{直径}+1-SubLine[L]-SubLine[R]+dis[R]-dis[L]+1$ 的最小值。 问题来了,这个式子和 $L,R$ 都有关系,暴力枚举的话得 $O(n^2)$。 ~~讲个笑话,这题数据过水,你这么写还真能给你过了~~ 不妨假设 $dis[L]<dis[R]$。 然后简单移一下项 $(2*(n-1)-\text{直径}+1)-(SubLine[L]+dis[L])+(-SubLine[R]+dis[R])$。 第一个括号里面是常量,第二个和第三个括号里的值对于每一个 $L$ 或 $R$ 都是确定的。 这很明显可以用单调栈优化 DP 做到 $O(n)$ 嘛。 #### 情况 4 ![](https://i.postimg.cc/ncG6Jq0V/1.png) 这个连法不属于上面的所有三种情况,非常容易漏(~~我找了区区一下午+一晚上~~)。 这个怎么办? 可以在给每个直径上的点求支链的同时解决这种情况。 在求支链的 dfs 我们记录一下每个节点子树中离自己最远的和次深的距离,记作 $MM,Se$。 则总距离就是 $2*(n-1)-\text{直径}+1-MM-Se+1$。 具体代码实现就是下面代码中的 dfs3。 #### 情况 5 就是树是一条链的情况,总距离就是 $2*(n-1)-\text{直径}+1+1$。 这 $5$ 种情况讨论完,就是~~令人崩溃~~的代码时间了。 ## Code ~~经过了你不懈的努力,你终于写完了所有情况~~ 代码非常粪,望见谅。 ```cpp /* 变量意义 dis 2遍dfs求树的直径时候用的那个dis数组 fa 父亲节点(找直径用的 book 标记直径上的点 dis2 帮忙求直径上的支链的长度 Start End 直径的起始点 SubLine 每个点(直径上的)的最长的支链 stk top 单调栈用品 */ #include<cstdio> #include<cstring> #include<algorithm> #include<cmath> #include<vector> using namespace std; template<typename T> inline void scan(T& x) { x=0; T f=0; char c=getchar(); while(c<'0') f|=(c=='-'),c=getchar(); while(c>='0') { x=(x<<1)+(x<<3)+(c&15); c=getchar(); } x=f?-x:x; } template<typename T,typename ... args> inline void scan(T& x, args& ... tmp)//快读 { scan(x); scan(tmp...); return; } int dis[100005],n; vector<int>G[100005]; int fa[100005]; bool book[100005]; int dis2[100005]; int Start=1,End=1; int ans1=0x7f7f7f7f,ans2=0x7f7f7f7f; vector<int>PointOnD; int SubLine[100005]; int K; int stk[100005],top; void dfs(int s,int Fa)//求树的直径 { fa[s]=Fa; for (int i=0,Next;i<G[s].size();i++) { Next=G[s][i]; if (Next==Fa||book[Next]) continue; dis[Next]=dis[s]+1; dfs(Next,s); } } int dfs3(int s,int Fa)//求支链+情况4 { bool flag=true; int MaxDeep=0,SeDeep=0; vector<int>Max; for (int i=0,Next,temp;i<G[s].size();i++) { Next=G[s][i]; if (Next==Fa||book[Next]) continue; dis2[Next]=dis2[s]+1; flag=false; temp=dfs3(Next,s); //维护最大值和次大值 if (temp>=MaxDeep) { SeDeep=MaxDeep; MaxDeep=temp; } else if (temp>SeDeep) SeDeep=temp; } if (MaxDeep&&SeDeep)//情况4 ans2=min(ans2,2*(n-1)-dis[End]+1-(MaxDeep-dis2[s])-(SeDeep-dis2[s])+1); if (flag) {return dis2[s];} else {return MaxDeep;} } int main() { scan(n,K); for (int i=1,x,y;i<n;i++) { scan(x,y); G[x].push_back(y); G[y].push_back(x); } dfs(1,0); for (int i=1;i<=n;i++) if (dis[Start]<dis[i]) Start=i; memset(dis,0,sizeof(dis)); memset(fa,0,sizeof(fa)); dfs(Start,0); for (int i=1;i<=n;i++) if (dis[End]<dis[i]) End=i; int pos=End; while (pos)//把直径抠出来 { book[pos]=true; PointOnD.push_back(pos); pos=fa[pos]; } for (int i=0;i<PointOnD.size();i++) book[PointOnD[i]]=true; if (K==1) { printf("%d",2*(n-1)-dis[End]+1); return 0; } else { if (dis[End]==n-1)//情况5 { printf("%d",2*(n-1)-dis[End]+1+1); return 0; } pair<int ,int >MM_All=make_pair(0,0),Se_All=make_pair(0,0); //全局中最长的和次长的支链 int MM,Se; for (int i=0,s;i<PointOnD.size();i++) { MM=Se=0; s=PointOnD[i]; for (int j=0,Next;j<G[s].size();j++) { Next=G[s][j]; if (book[Next]) continue; int temp; dis2[Next]=1; temp=dfs3(Next,s); if (temp>=MM) Se=MM,MM=temp; else if (temp>Se) Se=temp; if (temp>=MM_All.first) { Se_All=MM_All; MM_All=make_pair(temp,i); } else if (temp>Se_All.first) Se_All=make_pair(temp,i); SubLine[s]=max(SubLine[s],temp); } if (MM&&Se)//情况2 ans2=min(ans2,2*(n-1)-dis[End]+1-MM-Se+1); if (MM)//情况1 ans2=min(ans2,2*(n-1)-dis[End]+1-MM+1); } if (MM_All.second==Se_All.second) printf("%d",ans2); else { //情况3 for (int i=0;i<PointOnD.size();i++) { if (top) ans1=min(ans1,-(stk[1]+SubLine[PointOnD[stk[1]]])+(i-SubLine[PointOnD[i]])); while (top&&(-(stk[top]+SubLine[PointOnD[stk[top]]]))>=(-(i+SubLine[PointOnD[i]]))) top--; stk[++top]=i; } ans1+=2*(n-1)-dis[End]+1+1; printf("%d",min(min(ans1,ans2),2*(n-1)-dis[End]+1+1)); } } return 0; } ```