P11364 [NOIP2024] 树上查询 题解

· · 题解

P11364 NOIP 2024 T4 树上查询 非完美算法 -> 正解

简明题意

一棵树,每次给一个区间 [l, r] 和一个参数 k,要求是找到区间中子区间的所有点的 LCA 的深度最大。

似乎没有什么头绪,没事,可以从部分分做起。

8ptsq, n \leq 500

这部分怎么写呢。我们注意到,如果一个长度为 k 的区间,若是往后再加一个节点,不管怎么样,LCA 一定不会变大,所以说我们只需要枚举每一个长度为 k 的区间,算一遍 LCA 即可。 时间复杂度 O(qn ^ 2).

代码如下:

void sub1(){//8pts
    int n, q;
    cin >> n;
    for(int i = 1, x, y; i < n; i++)cin >> x >> y, adde(x, y), adde(y, x);
    dfs(1, 0);
    cin >> q;
    while(q--){
        int l, r, k, res = 0;
        cin >> l >> r >> k;
        for(int i = l; i <= r && i + k - 1 <= r; i++){
            int lca = i;
            // cout << i << " ";
            for(int j = i + 1; j <= (i + k - 1); j++){
                // cout << j << " ";
                lca = LCA(lca, j);
            }
            // cout << "---" << lca << " " << dep[lca];
            // cout << endl;
            res = max(res, dep[lca]);
        }
        cout << res << endl;
    }
}

20ptsq, n \leq 5000

这个范围明显是留给 O(qn \log n) 这类复杂度的。我们发现所选段是连续的,所以我们可以预处理。我们把之前的求出来的都存起来,很明显 f_{i,j} = LCA(f_{i, j - 1}, j),然后我们就可以在枚举每个段的时候 O(1) 求解。时间复杂度 O(n ^ 2\log n + qn)

代码如下:

void sub2(){//20pts
    int n, q;
    cin >> n;
    for(int i = 1, x, y; i < n; i++)cin >> x >> y, adde(x, y), adde(y, x);
    dfs(1, 0);
    cin >> q;
    for(int i = 1; i <= n; i++)f[i][i] = i;
    for(int i = 1; i <= n; i++){
        for(int j = i + 1; j <= n; j++){
            f[i][j] = LCA(f[i][j - 1], j);
        }
    }
    while(q--){
        int l, r, k;
        cin >> l >> r >> k;
        int res = 0;
        for(int i = l; i + k - 1 <= r; i++){
            res = max(res, dep[f[i][i + k - 1]]);
        }
        cout << res << endl;
    }
}

52pts 特殊性质 A

首先我们要注意到一个性质:

\max_{l\le l'\le r'\le r \land r'-l'+1\ge k}\text{dep}_ {\text{LCA*}(l', r')} = \min_{l \leq i \le r}dep_{LCA(i, i + 1)}

这个结论怎么证明呢。假设所有点的 LCA = u,那么至少有两个点在 u 的不同子树上,这说明什么,说明只需要两个点就可以确定了,如果有多个点不在同一个子树,那么一定有两个点可以把 LCA 升到 u 的位置。

知道了这个结论,再加上是特殊性质,我们考虑的对象从树上转化为了数组,考虑到可以找一个数据结构,支持插入一个点或者是删除一个点,这种情况下最后的答案就可以 整体二分 + 查询区间内最长连续段的长度是否大于 k,但是很显然,我不会整体二分。还有大佬跟我说可以用主席树去完成,我是既没搞懂原理,主席树我也不熟练。NOIP 官方的解题报告上面写着可以用 LCA + ST 表解决,但是我翻遍网上的题解,问了很多 AI,都不知道这部分怎么写。所以说这部分的代码暂时搁一下,以后一定会来写的。

不过呢,在 AI 的帮助下,我有一种可以通过 q, n \leq 1e5 的方法:

考虑每对点的 LCA 都有一个影响范围,这个样子我们就可以用单调栈去维护一个影响范围,然后用线段树维护区间 max

我们还要考虑到时间复杂度的问题,若是 k 特别大,我们全部预处理出来肯定不可能,所以我们考虑只枚举到 k = 100,而对于大的点我们就只考虑影响力大的点即可。(说白了就是暴力骗分)

代码如下:(部分代码参考 AI,出现了很多奇怪的语法)

void bfs(){
    for(int i = 1; i <= n; i++) dep[i] = 0;
    queue<int> q;
    q.push(1);
    dep[1] = 1;
    while(!q.empty()){
        int u = q.front();
        q.pop();
        for(int v : g[u]){
            if(dep[v] == 0){
                dep[v] = dep[u] + 1;
                q.push(v);
            }
        }
    }
    for(int i = 1; i <= n; i++) A[i] = dep[i];
}

void compLR(){
    stack<int> st;
    for(int i = 1; i <= n; i++){
        while(!st.empty() && A[st.top()] > A[i]){
            R[st.top()] = i;
            st.pop();
        }
        if(st.empty()) L[i] = 0;
        else L[i] = st.top();
        st.push(i);
    }
    while(!st.empty()){
        R[st.top()] = n + 1;
        st.pop();
    }
    for(int i = 1; i <= n; i++){
        Len[i] = R[i] - L[i] - 1;
    }
}
void sub3(){//40pts
    scanf("%d", &n);
    for(int i = 1; i <= n; i++) g[i].clear();
    for(int i = 0; i < n - 1; i++){
        int u, v;
        scanf("%d %d", &u, &v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    bfs();
    compLR();
    ST1 = new SegTree(n);
    ST1->build(1, 1, n, A);
    for(int k = 2; k <= K0; k++){
        int m = n - k + 1;
        int* B = new int[m + 1];
        for(int i = 1; i <= m; i++){
            int mn = A[i];
            for(int j = i + 1; j < i + k; j++){
                if(A[j] < mn) mn = A[j];
            }
            B[i] = mn;
        }
        STk[k] = new SegTree(m);
        STk[k]->build(1, 1, m, B);
        delete[] B;
    }
    vector<tuple<int, int, int, int>> lst;
    for(int i = 1; i <= n; i++){
        if(Len[i] > K0){
            lst.push_back({Len[i], i, L[i], R[i]});
        }
    }
    sort(lst.begin(), lst.end(), greater<tuple<int, int, int, int>>());
    int q;
    scanf("%d", &q);
    while(q--){
        int l, r, k;
        scanf("%d %d %d", &l, &r, &k);
        if(k == 1){
            int ans = ST1->query(1, 1, n, l, r);
            printf("%d\n", ans);
        } 
        else if(k <= K0){
            int le = l;
            int ri = r - k + 1;
            if(le > ri){
                int ans = -1;
                for(int i = l; i <= r; i++){
                    if(A[i] > ans) ans = A[i];
                }
                printf("%d\n", ans);
            } 
            else{
                int ans = STk[k]->query(1, 1, n - k + 1, le, ri);
                printf("%d\n", ans);
            }
        } 
        else{
            int ans = -1;
            for(auto& t : lst){
                int ln = get<0>(t);
                if(ln < k) break;
                int idx = get<1>(t);
                int Li = get<2>(t);
                int Ri = get<3>(t);
                if(l <= idx && idx <= r && Li <= r - k && l <= Ri - k){
                    if(A[idx] > ans) ans = A[idx];
                }
            }
            printf("%d\n", ans);
        }
    }
    delete ST1;
    for(int k = 2; k <= K0; k++){
        delete STk[k];
    }
    return 0;
}

64pts 特殊性质 B

考虑到只需要计算一次即可得出答案,我们怎么做呢,对于一整个区间来说,他的公共 LCA 一定是 dfs 序最小的和 dfs 序最大的的 LCA。证明如下:

u = LCA(u_{dfsmin}, u _ {dfsmax}),因为 dfs 顺序的问题,遍历顺序的先后保证了其间的节点一定在 u 的子树里。

依据此我们就能写出一份代码:

while (q--) {//省略了st表和dfs的部分
        int l, r, k;
        scanf("%d %d %d", &l, &r, &k);
        if (k == 1) {
            printf("%d\n", query_dep_max(l, r));
        } else {
            int M = r - k + 1;
            int ans = 0;
            for (int l0 = l; l0 <= M; l0++) {
                int r0 = l0 + k - 1;
                int u = query_min(l0, r0);
                int v = query_max(l0, r0);
                int w = lca(u, v);
                int d = dep[w];
                if (d > ans) ans = d;
            }
            printf("%d\n", ans);
        }
    }

正解

铺垫了这么多,终于讲到正解了。

我们在研究特殊性质 A 的时候发现两个结论

  • \max_{l\le l'\le r'\le r \land r'-l'+1\ge k}\text{dep}_ {\text{LCA*}(l', r')} = \min_{l \leq i \le r}dep_{LCA(i, i + 1)}
  • 考虑每对点的 LCA 都有一个影响范围,这个样子我们就可以用单调栈去维护一个影响范围

这给我们什么思考方向,可以离线去做 cdq,考虑每个 LCA 的影响范围 (L, R) 会对哪些询问 l, r 有贡献,这个是非常好看出来的: L \le r - k + 1 R \geq l + k - 1 R - L + 1 \geq k

这个很显然就是 cdq 了。给出代码:(正解借鉴了@qzmoot的题解)

#include <bits/stdc++.h>
#define Max(a, b) ((a) < (b)) ? (b) : (a)
using namespace std;
const int N = 1e6 + 10;
const int M = 1e6 + 10;
int n, q;
int nxt[M], to[M], h[M], ecnt;
void adde(int u, int v){
    nxt[++ecnt] = h[u];
    to[ecnt] = v;
    h[u] = ecnt;
} 
int F[N][30], dep[N];
void dfs(int u, int fa){
    dep[u] = dep[fa] + 1;
    F[u][0] = fa;
    for(int i = 1; i <= 25; i++){
        F[u][i] = F[F[u][i - 1]][i - 1];
    }
    for(int i = h[u]; i; i = nxt[i])
        if(to[i] != fa)
            dfs(to[i], u);
}
int LCA(int x, int y){
    if(dep[x] < dep[y])swap(x, y);
    for(int i = 25; i >= 0; i--){
        if(dep[y] + (1 << i) <= dep[x]) x = F[x][i];
    }
    if(x == y)return x;
    for(int i = 25; i >= 0; i--){
        if(F[x][i] != F[y][i])x = F[x][i], y = F[y][i];
    }
    return F[x][0];
}

void sub1(){//8pts
    cin >> n;
    for(int i = 1, x, y; i < n; i++)cin >> x >> y, adde(x, y), adde(y, x);
    dfs(1, 0);
    cin >> q;
    while(q--){
        int l, r, k, res = 0;
        cin >> l >> r >> k;
        for(int i = l; i <= r && i + k - 1 <= r; i++){
            int lca = i;
            // cout << i << " ";
            for(int j = i + 1; j <= (i + k - 1); j++){
                // cout << j << " ";
                lca = LCA(lca, j);
            }
            // cout << "---" << lca << " " << dep[lca];
            // cout << endl;
            res = Max(res, dep[lca]);
        }
        cout << res << endl;
    }
}
void sub2(){//20pts
    cin >> n;int f[5005][5005];
    for(int i = 1, x, y; i < n; i++)cin >> x >> y, adde(x, y), adde(y, x);
    dfs(1, 0);
    cin >> q;
    for(int i = 1; i <= n; i++)f[i][i] = i;
    for(int i = 1; i <= n; i++){
        for(int j = i + 1; j <= n; j++){
            f[i][j] = LCA(f[i][j - 1], j);
        }
    }
    while(q--){
        int l, r, k;
        cin >> l >> r >> k;
        int res = 0;
        for(int i = l; i + k - 1 <= r; i++){
            res = Max(res, dep[f[i][i + k - 1]]);
        }
        cout << res << endl;
    }
}
struct BCJ{
    int b[N];
    int lowbit(int x){return (x & (-x));}
    void add(int x, int val){
        if(!val)return;
        while(x <= n){
            b[x] = max(b[x], val), x += lowbit(x);
        }
    }
    void del(int x){
        while(x <= n)b[x] = 0, x += lowbit(x);
    }
    int qry(int x){
        int res = 0;
        while(x){
            res = max(b[x], res);
            x -= lowbit(x);
        }
        return res;
    }
}b;
struct Node{
    int id, l, r, len, dep, ans, pos;
    Node(){id = 0, l = 0, r = 0, len = 0, dep = 0, ans = 0, pos = 0;}
}a[N], a1[N];
struct Segment_Tree{
    #define lc (p << 1)
    #define rc (p << 1 | 1)
    #define mid ((l + r) >> 1)
    int tree[N << 2];
    void init(){
        memset(tree, 0, sizeof(tree));
    }
    void up(int p){tree[p] = max(tree[lc], tree[rc]);}
    void build(int p, int l, int r){
        // cout << l << " " << r << " " << p << endl;
        if(l > r)return;
        if(l == r){
            tree[p] = dep[l];
            return;
        }
        build(lc, l, mid);build(rc, mid + 1, r);
        up(p);
    }
    int query(int p, int l, int r, int ql, int qr){
        if(ql <= l && r <= qr){
            return tree[p];
        }
        int res = 0;
        if(ql <= mid)res = Max(res, query(lc, l, mid, ql, qr));
        if(qr > mid)res = Max(res, query(rc, mid + 1, r, ql, qr));
        return res;
    }
    #undef lc 
    #undef rc
    #undef mid
}tr;
int ccnt = 0;
int stk[N], tot, ans[N];
bool cmp1(Node x, Node y){
    if(x.len == y.len)return x.pos < y.pos;
    return x.len > y.len;
}
bool cmp2(Node x, Node y){
    if(x.r == y.r)return x.pos < y.pos;
    return x.r > y.r;
}
void cdq(int l, int r){
    if(l == r)return;
    int mid = (l + r) >> 1, tot1 = 0;

    for(int i = l; i <= mid; i++)if(a[i].pos == 0)a1[++tot1] = a[i];
    for(int i = mid + 1; i <= r; i++)if(a[i].pos)a1[++tot1] = a[i];
    sort(a1 + 1, a1 + tot1 + 1, cmp2);
    for(int i = 1; i <= tot1; i++){
        if(!a1[i].pos){
            b.add(a1[i].l, a1[i].dep);
        }
        else ans[a1[i].pos] = max(ans[a1[i].pos], b.qry(a1[i].l));
    }
    for(int i = 1; i <= tot1; i++){
        if(a1[i].pos == 0)b.del(a1[i].l);
    }cdq(l, mid), cdq(mid + 1, r);
}
void sub3(){//100pts
    cin >> n;
    for(int i = 1, x, y; i < n; i++)cin >> x >> y, adde(x, y), adde(y, x);
    dfs(1, 0); 
    tr.init();
    tr.build(1, 1, n);
    ccnt = 0;
    for(int i = 1; i < n; i++){
        Node t;
        t.id = LCA(i, i + 1), t.dep = dep[t.id];
        a[++ccnt] = t;
    }
    // 单调栈处理
    a[a[1].id].l = 1, stk[++tot] = 1;
    for(int i = 2; i < n; i++){
        int las = i;
        while(tot && a[stk[tot]].dep >= a[i].dep)las = a[stk[tot]].l, tot--;
        stk[++tot] = i;
        a[i].l = las;
    }
    a[n - 1].r = n - 1;
    tot = 0;
    memset(stk, 0, sizeof(stk));
    // cout << 114514;
    stk[++tot] = n - 1;
    for(int i = n - 2; i >= 1; i--){
        int las = i;
        while(tot && a[stk[tot]].dep >= a[i].dep)las = a[stk[tot]].r, tot--;
        stk[++tot] = i, a[i].r = las;
    }
    for(int i = 1; i <= ccnt; i++){
        a[i].r++, a[i].len = a[i].r - a[i].l + 1;
    }
    cin >> q;
    for(int i = 1; i <= q; i++){
        int ql, qr, k;
        cin >> ql >> qr >> k;
        if(k != 1){
            ccnt++, a[ccnt].id = 0, a[ccnt].l = qr - k + 1, a[ccnt].r = ql + k - 1; a[ccnt].len = k, a[ccnt].dep = a[ccnt].ans = 0, a[ccnt].pos = i;
        }  
        else ans[i] = (ql == qr) ? dep[ql] : tr.query(1, 1, n, ql, qr);
    }
    sort(a + 1, a + ccnt + 1, cmp1);
    cdq(1, ccnt);
    for(int i = 1; i <= q; i++)cout << ans[i] << "\n";
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);cout.tie(nullptr);
    // sub1();
    // sub2();
    sub3();
    return 0;
}