P11364 [NOIP2024] 树上查询 题解

· · 题解

c_i 为结点 i 和结点 i+1 的最近公共祖先,那么有一个性质是:当 l<r 时,\text{LCA*}(l,r) 等于 c_l\sim c_{r-1} 中深度最小的结点,且这样的结点唯一。

证明考虑分类讨论:

通过这个结论,我们可以知道,当 l<r 时:

dep_{\text{LCA*}(l,r)}=\min_{i=l}^{r-1} dep_{c_i}

于是我们只需要提前预处理出来每个 dep_{c_i} 就可以把树上问题转化为序列问题了!

a_i=dep_{c_i},那么我们可以对于每一个位置 i 求出最长的区间 [x_i,y_i] 满足 i\in[x_i,y_i]\min\limits_{j=x_i}^{y_i-1} a_j=a_i。于是现在的问题变成了,求出所有与 [l_j,r_j] 的交集大小至少为 k_j[x_i,y_i]a_i 的最大值。

再分类讨论一下:

对于第一种情况,可以将 (l_j,r_j,k_j) 按照 r_j-k_j+1 从小到大排序,将 (x_i,y_i,a_i) 按照 x_i 从小到大排序,做一次扫描线,用线段树对满足 r_j \le y_ii 统计 a_i 的最大值。

对于第二种情况,可以将 (l_j,r_j,k_j) 按照 k_j 从大到小排序,将 (x_i,y_i,a_i) 按照 y_i-x_i+1 从大到小排序,做一次扫描线,用线段树对满足 l_j+k_j-1\le y_i\lt r_ji 统计 a_i 的最大值。

注意特判 k=1 的情况,用 ST 表维护。这里认为 n,q 同阶,时间复杂度 \mathcal O(n \log n)

#include <bits/stdc++.h>

#define ll long long
#define ull unsigned long long
#define i128 __int128
#define endl '\n'
#define pb push_back
#define pf push_front
#define pii pair<int,int>
#define fi first
#define se second
#define vei vector<int>
#define pq priority_queue
#define lb lower_bound
#define ub upper_bound
#define yes puts("yes")
#define no puts("no")
#define Yes puts("Yes")
#define No puts("No")
#define YES puts("YES")
#define NO puts("NO")
#define In(x) freopen(x".in","r",stdin)
#define Out(x) freopen(x".out","w",stdout)
#define File(x) (In(x),Out(x))
using namespace std;
const int N=5e5+5,L=19,inf=1e9;
int n,q,fa[N][L],dep[N],a[N],st[N][L],pw[L],lg[N],val[N<<2];
struct P{
    int x,y,a;
}s[N];
struct Q{
    int l,r,k,ans,id;
}t[N];
bool cmps1(P a,P b){
    return a.x<b.x;
}
bool cmpt1(Q a,Q b){
    return a.r-a.k<b.r-b.k;
}
bool cmps2(P a,P b){
    return a.y-a.x>b.y-b.x;
}
bool cmpt2(Q a,Q b){
    return a.k>b.k;
}
bool cmpt3(Q a,Q b){
    return a.id<b.id;
}
vector <int> ve[N];
void init(int u,int f){
    fa[u][0]=f,dep[u]=dep[f]+1;
    for(auto v:ve[u]){
        if(v==f) continue;
        init(v,u);
    }
}
int lca(int u,int v){
    if(dep[u]<dep[v]) swap(u,v);
    for(int i=L-1;i>=0;i--) if(dep[fa[u][i]]>=dep[v]) u=fa[u][i];
    if(u==v) return u;
    for(int i=L-1;i>=0;i--) if(fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i];
    return fa[u][0];
}
int askst1(int l,int r){
    int k=lg[r-l+1];
    return min(st[l][k],st[r-pw[k]+1][k]);
}
int askst2(int l,int r){
    int k=lg[r-l+1];
    return max(st[l][k],st[r-pw[k]+1][k]);
}
#define ls (g<<1)
#define rs (g<<1|1)
#define mid ((l+r)>>1)
void upd(int g){
    val[g]=max(val[ls],val[rs]);
}
void build(int g,int l,int r){
    if(l==r) return val[g]=0,void();
    build(ls,l,mid);
    build(rs,mid+1,r);
    upd(g);
}
void modify(int g,int l,int r,int x,int v){
    if(l==x&&x==r) return val[g]=max(v,val[g]),void();
    if(r<x||x<l) return;
    modify(ls,l,mid,x,v);
    modify(rs,mid+1,r,x,v);
    upd(g);
}
int ask(int g,int l,int r,int x,int y){
    if(x<=l&&r<=y) return val[g];
    if(r<x||y<l) return 0;
    return max(ask(ls,l,mid,x,y),ask(rs,mid+1,r,x,y));
}
void solve(){
    cin>>n;
    for(int i=1;i<n;i++){
        int u,v;
        cin>>u>>v;
        ve[u].pb(v);
        ve[v].pb(u);
    }
    pw[0]=1,lg[1]=0,init(1,0);
    for(int i=1;i<L;i++) pw[i]=pw[i-1]*2;
    for(int i=2;i<=n;i++) lg[i]=lg[i/2]+1;
    for(int i=1;i<L;i++) for(int u=1;u<=n;u++) fa[u][i]=fa[fa[u][i-1]][i-1];
    for(int i=1;i<n;i++) st[i][0]=a[i]=dep[lca(i,i+1)];
    for(int i=1;i<L;i++) for(int u=1;u<n;u++) st[u][i]=min(st[u][i-1],u+pw[i-1]<n?st[u+pw[i-1]][i-1]:inf);
    for(int u=1;u<n;u++){
        s[u].a=a[u];
        int l=1,r=u;
        while(l<r){
            int m=(l+r)>>1;
            if(askst1(m,u)<a[u]) l=m+1;
            else r=m;
        }
        s[u].x=l;
        l=u,r=n-1;
        while(l<r){
            int m=(l+r+1)>>1;
            if(askst1(u,m)<a[u]) r=m-1;
            else l=m;
        }
        s[u].y=l+1;
    }
    cin>>q;
    for(int i=1;i<=q;i++) cin>>t[i].l>>t[i].r>>t[i].k,t[i].id=i;
    sort(s+1,s+n,cmps1);
    sort(t+1,t+q+1,cmpt1);
    build(1,1,n);
    for(int i=1,j=1;j<=q;j++){
        if(t[j].k==1) continue;
        while(i<n&&s[i].x<=t[j].r-t[j].k+1) modify(1,1,n,s[i].y,s[i].a),i++;
        t[j].ans=ask(1,1,n,t[j].r,n);
    }
    sort(s+1,s+n,cmps2);
    sort(t+1,t+q+1,cmpt2);
    build(1,1,n);
    for(int i=1,j=1;j<=q;j++){
        if(t[j].k==1) continue;
        while(i<n&&s[i].y-s[i].x+1>=t[j].k) modify(1,1,n,s[i].y,s[i].a),i++;
        t[j].ans=max(t[j].ans,ask(1,1,n,t[j].l+t[j].k-1,t[j].r-1));
    }
    sort(t+1,t+q+1,cmpt3);
    for(int i=1;i<n;i++) st[i][0]=dep[i];
    for(int i=1;i<L;i++) for(int u=1;u<n;u++) st[u][i]=max(st[u][i-1],u+pw[i-1]<n?st[u+pw[i-1]][i-1]:0);
    for(int i=1;i<=q;i++) if(t[i].k==1) t[i].ans=askst2(t[i].l,t[i].r);
    for(int i=1;i<=q;i++) cout<<t[i].ans<<endl;
}
signed main(){
    ios::sync_with_stdio(0);
    signed T=1;
//  cin>>T;
    while(T--) solve();
    return 0;
}