P2633 Count on a tree

· · 题解

题面

给一棵树,每个点有点权。m 次询问 u,v,k,求 u\to v 的路径上所有点的点权中,第 k 小的那个。

询问强制在线:u\;xor\;lst\to u

题解

区间第 k 小可以用主席树加区间前缀和解决。具体而言,如果 [1,mid] 的值域内,前 x 个数有 s1 个,前 y 个数里有 s2 个,则 [x+1,y] 内在值域内的数,有 s2-s1 个。

延展到树上,我们每个节点从它的父节点版本拓展出来,构建主席树。于是 rt[u] 版本的权值线段树,代表的就是 u1 这条路径上点塞入权值线段树。

类似于区间第 k 小的差分,我们在树上也可以差分。

u,v 路径上的第 k 小,则运用公式:sum[u]+sum[v]-sum[lca]-sum[fa[lca]] 即可判断是在左区间还是右区间。

CODE:


#include<bits/stdc++.h>
#define ls(p) tr[p].ls
#define rs(p) tr[p].rs
using namespace std;
const int N=1e5+5;
int n,m,lst,a[N],b[N],c[N],ctp;
int rt[N],tot,g[N][22],dep[N];
vector<int> e[N];
struct NODE{
    int ls,rs,sum;
}tr[N*22];
void ins(int v,int &u,int l,int r,int x){
    //v 依赖版本,u 新开版本 
    u=++tot;
    tr[u]=tr[v], tr[u].sum=tr[v].sum+1;
    if(l==r) return;
    int mid=(l+r)>>1;
    if(x<=mid) ins(ls(v),ls(u),l,mid,x);
    else ins(rs(v),rs(u),mid+1,r,x);
}
int query(int u,int v,int x,int y,int l,int r,int k){
    //lca(u,v)=x, y=fa[x], 当前权值范围:[l,r]
    if(l==r) return l;
    int s=tr[ls(u)].sum+tr[ls(v)].sum-tr[ls(x)].sum-tr[ls(y)].sum;
    int mid=(l+r)>>1;
    if(k<=s) return query(ls(u),ls(v),ls(x),ls(y),l,mid,k);
    return          query(rs(u),rs(v),rs(x),rs(y),mid+1,r,k-s);
}
void dfs(int u,int fa){
    dep[u]=dep[fa]+1;
    g[u][0]=fa;
    ins(rt[fa],rt[u],1,ctp,a[u]);
    for(int i=1;i<=20;i++){
        g[u][i]=g[g[u][i-1]][i-1];
    }
    for(auto v:e[u]){
        if(v==fa) continue;
        dfs(v,u);
    }
}
int LCA(int x,int y){
    if(dep[x]<dep[y]) swap(x,y);
    for(int i=20;i>=0;i--){
        if(dep[g[x][i]]>=dep[y]) x=g[x][i];
    }
    if(x==y) return x;
    for(int i=20;i>=0;i--){
        if(g[x][i]!=g[y][i]){
            x=g[x][i];
            y=g[y][i];
        }
    }
    return g[x][0];
}
signed main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin>>n>>m;
    for(int i=1;i<=n;i++){
        cin>>a[i];
        b[i]=a[i];
    }
    sort(b+1,b+1+n);
    for(int i=1;i<=n;i++){
        if(i==1 || b[i]!=b[i-1]){
            c[++ctp]=b[i];
        }
    }
    for(int i=1;i<=n;i++){
        int x=lower_bound(c+1,c+1+ctp,a[i])-c;
        a[i]=x;
    }
    for(int i=1;i<n;i++){
        int x,y; cin>>x>>y;
        e[x].push_back(y);
        e[y].push_back(x);
    }
    dfs(1,0);
    while(m--){
        int x,y,k,lca;
        cin>>x>>y>>k;
        x=(x^lst), lca=LCA(x,y);
        lst=c[query(rt[x],rt[y],rt[lca],rt[g[lca][0]],1,ctp,k)];
        cout<<lst<<'\n';
    }
    return 0;
}