题解:CF1458F Range Diameter Sum
借鉴了这篇题解,并补充了一些细节与证明。
思路
考虑从大往小枚举区间的左端点,维护每个右端点
思考
考虑左端点
情况一:
以后的
此时也得到了一个性质,一定是一段前缀的
情况二:
首先有一个性质,就是此时一定满足
若
(u_i,v_i)=(u_{i-1},v_{i-1}) ,显然两个路径的变化是相同的;否则,
u_i<v_i=i ,此时新的u'_i,v'_i 都比i 要小,那么u'_{i-1},v'_{i-1} 也一定可以取到这个最大的路径。
此时把相邻的
情况三:
此时有一个性质,把相邻的
对于上图这三种情况,显然满足。
注意,我们允许
d_1,d_3,d_4,d_5 中的任何一个为0 ,但是d_2=0 属于第一张图的第三种情况。对于这种情况,考虑我们知道的条件,即在
u,u',v_1 三个点组成的路径中,(u',v_1) 一定是最长的,在u,v_1,v_2 三个点组成的路径中,(u,v_2) 一定是最长的。用字母表示这些条件:
联立第一个式子和第三个式子,可以得到
d_3+d_2\le d_1\le d_3-d_2 ,那么此时d_2 一定为0 ,属于第一张图的第三种情况,这种情况不存在。注意,我们允许
d_1,d_2,d_4,d_5 中的任何一个为0 ,但是d_3=0 属于第一张图的第三种情况。对于这种情况,类似上一种列出式子:
联立可得
d_1=d_2 ,那么(u',u) 和(u',v_1) 长度相同,若我们优先把u' 和u 组成一条路径,即优先判断情况二,这种情况就一定不会存在。
仍然是维护连续段,但是情况二可能会把一个连续段给分裂。所以可以使用一个栈,每次把整段都被改的删掉,最后加入一个大段,而做完整个过程后最多只有一个段分裂,其他都已经被修改了,这个直接暴力做即可。
还有一个问题,就是如何把当前的
那么就做完了,
代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N = 1e6+5;
int n,dfn[N],sz[N],son[N],idx,dep[N],ff[N],top[N];
vector<int> g[N];
void dfs1(int u,int fa)
{
ff[u] = fa,dep[u] = dep[fa]+1,sz[u] = 1;
for(auto v:g[u])
{
if(v==fa) continue;
dfs1(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]]) son[u] = v;
}
}
void dfs2(int u,int tp)
{
dfn[u] = ++idx,top[u] = tp;
if(!son[u]) return;
dfs2(son[u],tp);
for(auto v:g[u])
{
if(v==ff[u]||v==son[u]) continue;
dfs2(v,v);
}
}
inline int lca(int x,int y)
{
while(top[x]^top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x = ff[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
return x;
}
inline int dis(int x,int y){return dep[x]+dep[y]-2*dep[lca(x,y)];}
int f[N];
int find(int x)
{
if(f[x]==x) return x;
return f[x] = find(f[x]);
}
inline void merge(int x,int y)
{
x = find(x),y = find(y);
f[x] = y;
}
int stk[N],tot,c[N];
signed main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
cin>>n;
for(int i = 1,u,v;i<n;i++)
cin>>u>>v,g[u].push_back(v),g[v].push_back(u);
dfs1(1,0),dfs2(1,1);
int sum = 0,ans = 0;
stk[0] = n+1;
for(int i = n-1;i;i--)
{
f[i+1] = i+1;
sum+=dis(i,i+1);
int p = i+2;
while(p<=n)
{
int x = p,y = c[tot];
int d1 = dis(i,x),d2 = dis(i,y),d3 = dis(x,y),mx = max({d1,d2,d3});
if(d3==mx) break;
else if(d2==mx)
{
stk[tot] = find(p)+1;
if(stk[tot]==stk[tot-1]) tot--;
sum+=(d2-d3)*(find(p)-p+1);
merge(p-1,p);
p = find(p)+1;
}
else
{
sum+=(d1-d3)*(stk[tot-1]-stk[tot]);
p = stk[--tot];
}
}
tot++;
stk[tot] = i+1,c[tot] = i;
ans+=sum;
}
cout<<ans;
return 0;
}