P5327 [ZJOI2019]语言(线段树合并+树上差分)

· · 题解

P5327 [ZJOI2019]语言

思路1:每个点能到达的点集为所有经过它的路径的并集。对于每个点,将经过它的路径上的点权赋为 1 ,全局求和,复杂度 O(mn\log n)

思路2:在一条链上,对每个点维护一下它能到的最左端和最右端。每个询问相当于区间取min/max。

考虑优化思路 1 ,借鉴雨天的尾巴的思路,将询问差分,用线段树合并维护经过这个点的所有路径。

又发现每个点能到达的点集为包含所有经过它的路径的端点的极小连通块,即一棵树。再借鉴寻宝游戏的思路,这个极小连通块的点数-1=边数即为按 dfn 排序后总 dep 减去相邻两点间 lca 的 dep,用线段树维护这个极小连通块的大小即可;具体地,线段树上每个节点维护这个区间内的距离和、dfn最小/最大的点。

如果写 st表求 lca 的话复杂度为 O(n\log n)

所以这题其实挺套路的,就是雨天的尾巴+寻宝游戏二合一

#include<bits/stdc++.h>
#define mid ((l+r)>>1)
#define ls t[ro].l
#define rs t[ro].r
using namespace std;
typedef long long ll;
namespace FGF
{
    int n,m;
    const int N=2e5+5;
    vector<int> g[N];
    int dep[N],st[23][N],dfn[N],num,fa[N],lo[N],rt[N],cnt;
    int Min(int u,int v){return dep[u]<dep[v]?u:v;}
    ll ans;
    struct tree{
        int l,r,mn,mx,cnt;
        ll sum;
    }t[N*40];
    void dfs(int u,int f)
    {
        dep[u]=dep[f]+1,fa[u]=f,dfn[u]=++num,st[0][num]=u;
        for(auto v:g[u])
            if(v!=f)dfs(v,u),st[0][++num]=u;
    }
    void init()
    {
        lo[0]=-1;
        for(int i=1;i<=num;i++)
            lo[i]=lo[i>>1]+1;
        for(int j=1;j<=lo[num];j++)
            for(int i=1;i+(1<<j)-1<=num;i++)
                st[j][i]=Min(st[j-1][i],st[j-1][i+(1<<(j-1))]);
    }
    int getlca(int l,int r)
    {
        if(!l||!r)return 0;
        if(l>r)swap(l,r);
        int k=lo[r-l+1];
        return Min(st[k][l],st[k][r-(1<<k)+1]);
    }
    inline void pushup(int ro)
    {
        t[ro].mn=t[ls].mn?t[ls].mn:t[rs].mn,t[ro].mx=t[rs].mx?t[rs].mx:t[ls].mx;
        t[ro].sum=t[ls].sum+t[rs].sum-dep[getlca(t[ls].mx,t[rs].mn)];
    }
    void inser(int &ro,int l,int r,int x,int op)
    {
        if(!ro)ro=++cnt;
        if(l==r)
        {
            t[ro].cnt+=op;
            t[ro].mn=t[ro].mx=(t[ro].cnt>0?l:0),t[ro].sum=(t[ro].cnt>0?dep[x]:0);
            return;
        }
        dfn[x]<=mid?inser(ls,l,mid,x,op):inser(rs,mid+1,r,x,op);
        pushup(ro);
    }
    int merge(int x,int y,int l,int r)
    {
        if(!x||!y)return x+y;
        if(l==r)
        {
            t[x].cnt+=t[y].cnt;
            t[x].mn=t[x].mx=(t[x].cnt>0?l:0),t[x].sum=(t[x].cnt>0?dep[st[0][l]]:0);
            return x;
        }
        t[x].l=merge(t[x].l,t[y].l,l,mid),t[x].r=merge(t[x].r,t[y].r,mid+1,r);
        pushup(x);
        return x;
    }
    void dfs2(int u,int f)
    {
        for(auto v:g[u])
            if(v!=f)dfs2(v,u),rt[u]=merge(rt[u],rt[v],1,num);
        ans+=t[rt[u]].sum-dep[getlca(t[rt[u]].mn,t[rt[u]].mx)];
    }
    void work()
    {
        scanf("%d%d",&n,&m);
        for(int i=1,u,v;i<n;i++)
            scanf("%d%d",&u,&v),g[u].push_back(v),g[v].push_back(u);
        dfs(1,0),init();
        for(int i=1,s,t,f;i<=m;i++)
        {
            scanf("%d%d",&s,&t);f=getlca(dfn[s],dfn[t]);
            inser(rt[s],1,num,s,1),inser(rt[s],1,num,t,1);
            inser(rt[t],1,num,s,1),inser(rt[t],1,num,t,1);
            inser(rt[f],1,num,s,-1),inser(rt[f],1,num,t,-1);
            if(fa[f])inser(rt[fa[f]],1,num,s,-1),inser(rt[fa[f]],1,num,t,-1);
        }
        dfs2(1,0);
        printf("%lld",ans/2);
    }
}
int main()
{
    FGF::work();
    return 0;
}