题解:CF1458F Range Diameter Sum

· · 题解

借鉴了这篇题解,并补充了一些细节与证明。

思路

考虑从大往小枚举区间的左端点,维护每个右端点 i 对应的直径 (u_i,v_i),钦定 u_i<v_i。维护 sum=\sum dis(u_i,v_i),那么答案即为所有左端点对应的 sum 之和。

思考 (u_i,v_i)(u_{i-1},v_{i-1}) 之间的关系,根据直径的经典性质,若 v_i=v_{i-1},则 u_{i}=u_{i-1},否则,v_{i}=iu_iu_{i-1},v_{i-1} 中的一个。

考虑左端点 l+1 变成了 l,设 (u_i,v_i) 变成了 (u'_i,v'_i)。如果要改变,那么 u'_i 一定为 l。从小往大扫 i,那么有以下几种情况:

情况一(u'_i,v'_i)=(u_i,v_i)

以后的 i 也一定是这种情况,直接跳出循环即可。

此时也得到了一个性质,一定是一段前缀的 u'_i 被改为了 l

情况二(u'_i,v'_i)=(l,u_i)

首先有一个性质,就是此时一定满足 (u'_i,v'_i)=(u'_{i-1},v'_{i-1}),证明如下:

  • (u_i,v_i)=(u_{i-1},v_{i-1}),显然两个路径的变化是相同的;

  • 否则,u_i<v_i=i,此时新的 u'_i,v'_i 都比 i 要小,那么 u'_{i-1},v'_{i-1} 也一定可以取到这个最大的路径。

此时把相邻的 v_i 相同的合成一段,若在某一段的开头 i 发生这种情况,那么整个段都会变成 (l,u_i),且这个段会和 i-1 所在的段发生合并。对于 sum 来说,设这一段的长度为 len,其变化量为 len\times(dis(l,u_i)-dis(u_i,v_i))。使用并查集维护这些段即可。

情况三(u'_i,v'_i)=(l,v_i)

此时有一个性质,把相邻的 u_i 相同的合成一段,若一个段内的开头 i 发生了这种情况,则这一整段也会发生这种情况,且路径长度变化量相同。证明如下(为了方便下图中的 u 表示这里的 u_i,而 u' 表示这里的 lv_1 表示 v_iv_2 表示任意一个和 i 在同一段对应的 v):

对于上图这三种情况,显然满足。

注意,我们允许 d_1,d_3,d_4,d_5 中的任何一个为 0,但是 d_2=0 属于第一张图的第三种情况。

对于这种情况,考虑我们知道的条件,即在 u,u',v_1 三个点组成的路径中,(u',v_1) 一定是最长的,在 u,v_1,v_2 三个点组成的路径中,(u,v_2) 一定是最长的。用字母表示这些条件:

联立第一个式子和第三个式子,可以得到 d_3+d_2\le d_1\le d_3-d_2,那么此时 d_2 一定为 0,属于第一张图的第三种情况,这种情况不存在。

注意,我们允许 d_1,d_2,d_4,d_5 中的任何一个为 0,但是 d_3=0 属于第一张图的第三种情况。

对于这种情况,类似上一种列出式子:

联立可得 d_1=d_2,那么 (u',u)(u',v_1) 长度相同,若我们优先把 u'u 组成一条路径,即优先判断情况二,这种情况就一定不会存在。

仍然是维护连续段,但是情况二可能会把一个连续段给分裂。所以可以使用一个栈,每次把整段都被改的删掉,最后加入一个大段,而做完整个过程后最多只有一个段分裂,其他都已经被修改了,这个直接暴力做即可。sum 的变化量也是简单的,每次修改一个长度为 len 的段的 u,那么变化量为 len\times(dis(l,u_i)-dis(u_i,v_i))

还有一个问题,就是如何把当前的 u_i,v_i 取出来。这个是简单的,由于是从小往大扫 i,所以可以维护一个指针表示当前 u 处于哪个段。由于一个 v 的连续段对应的 u 一定相同,我每次都是修改一个 u 的整段或者 v 的整段,所以 i 当前一定是 v 对应的段的开头。

那么就做完了,uv 的连续段只会发生 O(n) 次改变。由于要查询路径长度,时间复杂度 O(n\log n)

代码

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N = 1e6+5;
int n,dfn[N],sz[N],son[N],idx,dep[N],ff[N],top[N];
vector<int> g[N];
void dfs1(int u,int fa)
{
    ff[u] = fa,dep[u] = dep[fa]+1,sz[u] = 1;
    for(auto v:g[u])
    {
        if(v==fa) continue;
        dfs1(v,u);
        sz[u]+=sz[v];
        if(sz[v]>sz[son[u]]) son[u] = v;
    }
}
void dfs2(int u,int tp)
{
    dfn[u] = ++idx,top[u] = tp;
    if(!son[u]) return;
    dfs2(son[u],tp);
    for(auto v:g[u])
    {
        if(v==ff[u]||v==son[u]) continue;
        dfs2(v,v);
    }
}
inline int lca(int x,int y)
{
    while(top[x]^top[y])
    {
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x = ff[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    return x;
}
inline int dis(int x,int y){return dep[x]+dep[y]-2*dep[lca(x,y)];}
int f[N];
int find(int x)
{
    if(f[x]==x) return x;
    return f[x] = find(f[x]); 
}
inline void merge(int x,int y)
{
    x = find(x),y = find(y);
    f[x] = y; 
}
int stk[N],tot,c[N];
signed main()
{
//  freopen(".in","r",stdin);
//  freopen(".out","w",stdout);
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    cin>>n;
    for(int i = 1,u,v;i<n;i++)
        cin>>u>>v,g[u].push_back(v),g[v].push_back(u);
    dfs1(1,0),dfs2(1,1);
    int sum = 0,ans = 0;
    stk[0] = n+1;
    for(int i = n-1;i;i--)
    {
        f[i+1] = i+1;
        sum+=dis(i,i+1);
        int p = i+2;
        while(p<=n)
        {
            int x = p,y = c[tot];
            int d1 = dis(i,x),d2 = dis(i,y),d3 = dis(x,y),mx = max({d1,d2,d3});
            if(d3==mx) break;
            else if(d2==mx)
            {
                stk[tot] = find(p)+1;
                if(stk[tot]==stk[tot-1]) tot--;
                sum+=(d2-d3)*(find(p)-p+1);
                merge(p-1,p);
                p = find(p)+1;
            }
            else
            {
                sum+=(d1-d3)*(stk[tot-1]-stk[tot]);
                p = stk[--tot];
            }
        }
        tot++;
        stk[tot] = i+1,c[tot] = i;
        ans+=sum;
    }
    cout<<ans;
    return 0;
}