题解:P5351 Ruri Loves Maschera

· · 题解

分析

树上所有路径问题,考虑点分治。

点分治的核心思想是先找到连通块的重心 rt,算经过 rt 的所有合法路径的贡献,再删掉 rt 递归处理它的子树。

设当前重心为 rt。对于树上的任意一个点 u,记 mx_u 代表从 rtu 路径上的最大边权,k_u 代表从 rtu 的路径长度(边数)

那么对于两个点 uv,它们之间经过 rt 的路径魔力值为 \max(mx_u,mx_v),长度为 k_u+k_v。则现在要统计所有满足 L\le k_u+k_v\le R 的点对 (u,v)\max(mx_u,mx_v) 之和。

怎么算 mx_u 的贡献呢,只要令 \max(mx_u,mx_v)=mx_u 并统计这样的点对数量不就好了,将当前连通块内的所有点按照 mx 值从小到大排序,上数据结构维护所有满足 mx_v\le mx_u 的点 v 即可。因此,当前点 u 对答案的贡献为 mx_u\times cnt。其中 cnt 代表满足长度条件 L-k_u\le k_v\le R-k_u 的点 v 数量。

需要一个支持以下操作的数据结构单点加区间查的数据结构,树状数组即可,注意树状数组中下标不能为 0

总时间复杂度 O(n\log^2 n)

::::success[AC Code]

#include<bits/stdc++.h>
#define int long long
using namespace std;
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9')x=x*10+ch-'0',ch=getchar();
    return x*f;
}
const int maxn=1e5+7;
int n,L,R,sz[maxn],w[maxn],cnt,root;
bool vis[maxn];
vector<pair<int,int> > g[maxn];
struct MoonHalo
{
    int maxv,k;
} p[maxn];
int tot;
struct BIT
{
    int c[maxn*5];
    int lowbit(int x){return x&(-x);}
    void add(int x,int y)
    {
        if(x<0) return;
        for(x++;x<=n+5;x+=lowbit(x)) c[x]+=y;
    }
    int query(int x)
    {
        if(x<0) return 0;
        int res=0;
        for(x++;x;x-=lowbit(x)) res+=c[x];
        return res;
    }
}bit;
int ans=0;
void dfs(int u,int f,int maxv,int len)
{
    p[++tot]={maxv,len};
    for(auto nx:g[u])
    {
        int v=nx.first,w=nx.second;
        if(v==f||vis[v]) continue;
        dfs(v,u,max(maxv,w),len+1);
    }
}
int calc(int u,int mxx,int len)
{
    tot=0;
    dfs(u,0,mxx,len);
    sort(p+1,p+tot+1,[](MoonHalo a,MoonHalo b){return a.maxv<b.maxv;});
    int res=0;
    for(int i=1;i<=tot;i++)
    {
        int l=L-p[i].k,r=R-p[i].k;
        if(r>=0) res+=(bit.query(r)-bit.query(l-1))*p[i].maxv;
        bit.add(p[i].k,1);
    }
    for(int i=1;i<=tot;i++) bit.add(p[i].k,-1);
    return res;
}
void dfssz(int u,int f)
{
    sz[u]=1;
    for(auto nx:g[u])
    {
        int v=nx.first;
        if(v==f||vis[v]) continue;
        dfssz(v,u);
        sz[u]+=sz[v];
    }
}
void getroot(int u,int f)
{
    sz[u]=1;
    w[u]=0;
    for(auto nx:g[u])
    {
        int v=nx.first;
        if(v==f||vis[v]) continue;
        getroot(v,u);
        sz[u]+=sz[v];
        w[u]=max(w[u],sz[v]);
    }
    w[u]=max(w[u],cnt-sz[u]);
    if(w[u]<w[root]) root=u;
}
void divide(int u)
{
    dfssz(u,0);
    cnt=sz[u],root=0;
    getroot(u,0);
    u=root;
    vis[u]=true;
    for(auto nx:g[u])
    {
        int v=nx.first,w=nx.second;
        if(vis[v]) continue;
        ans-=calc(v,w,1);
    }
    ans+=calc(u,0,0);
    for(auto nx:g[u])
    {
        int v=nx.first,w=nx.second;
        if(vis[v]) continue;
        divide(v);
    }
}
signed main()
{
    n=read(),L=read(),R=read();
    for(int i=1;i<n;i++)
    {
        int u,v,w;
        u=read(),v=read(),w=read();
        g[u].push_back({v,w});
        g[v].push_back({u,w});
    }
    w[0]=INT_MAX;
    divide(1);
    cout<<ans*2;
    return 0;
}

::::