题解:P14211 [ROI 2016 Day2] 快递服务

· · 题解

神秘大常数 O(n\log^2n) 做法。

首先我们考虑,路径 (a_1,b_1),(a_2,b_2) 的交一定还是路径,假设为 (u,v),其只有祖孙链和非祖孙链两种形式,不妨令 dep_u \ge dep_v,分讨:

枚举 u,贪心地,我们一定会选一端在 u 子树内,一端在 u 子树外的路径中,lca 最浅的两条。所以求出 lca 后这部分可以 O(n) 解决。

那么两路径一定 lca 相同,且两端所在 lca 的子树也相同。我们枚举 lca,然后对每组所处子树相同的路径计算答案。

假设当前 lca 为 c,且 a_1,a_2c 的儿子 x 的子树中,b_1,b_2c 的儿子 y 的子树中。考虑两条路径会在两棵子树内各产生一个交点,我们枚举在 x 子树内的 u,然后在 u 子树合并时计算 a_1,a_2 对应产生的贡献。

假设当前 u 子树内已经维护了 A 集合内的路径,当前要加入路径 (a_2,b_2)。那么我们考虑贪心地,v 一定是 b_2 的祖先中,最深的子树内有 A 内节点的点。显然有二段性,我们可以对 b_2 的祖先二分求出这个点。

现在问题是如何快速查询一个点子树内是否有 A 内的点。我们对 A 内点 +1,于是子树内有点等价于子树和非零,所以可以树状数组维护单点加区间查。

u 的枚举,考虑显然只需枚举 x 子树内路径端点虚树上的点即可。加入路径 (a_2,b_2) 使用树上启发式合并,保存重儿子的树状数组,暴力单点加轻儿子。于是修改查询复杂度均为 O(n\log^2n)

代码赛时写的比较猎奇,将就看一下(。

::::info[code]

#include <bits/stdc++.h>
using namespace std;

using LL = long long;
namespace IO
{
    constexpr int S = 1 << 26;
    char b1[S], *p1, *p2;
    char b2[S], *p3 = b2;

    #define G (p1 == p2 && (p2 = (p1 = b1) + fread(b1, 1, S, stdin)), p1 == p2 ? EOF : *p1 ++)
    #define P (p3 == b2 + S && (p3 = b2, fwrite(b2, S, 1, stdout)), *p3 ++)

    int R()
    {
        int x = 0, c = G;
        while (!isdigit(c)) c = G;
        while (isdigit(c)) x = x * 10 + (c & 15), c = G;
        return x;
    }

    void W(LL x)
    {
        static char stk[20]; int top = 0;
        do stk[top ++] = x % 10 | '0', x /= 10; while (x);
        while (top --) P = stk[top];
        P = ' ';
    }

    void flush() {fwrite(b2, p3 - b2, 1, stdout), p3 = b2;}

    #undef G
    #undef P
}
using IO::R;
using IO::W;

constexpr int N = 2e5 + 5, L = 18;

int n, m;
int par[N], deg[N], gra[N];
array <int, 2> pat[N];
int dfn[N], dep[N], spt[L][N];
array <int, 4> tag[N];
vector <int> idx[N];

vector <int> trl[N], trr[N], ano[N];
int siz[N], son[N];
int tin[N], tou[N], anc[N][L], bit[N];
int res, resa, resb;
int ans, ansa, ansb;
int tmp, tmpa, tmpb;

void dfs(int u)
{
    spt[0][dfn[u] = ++dfn[0]] = u, dep[u] = dep[par[u]] + 1;
    for (int i = deg[u]; i < deg[u + 1]; i ++) dfs(gra[i]);
}

int cmp(int x, int y)
{
    return dep[x] < dep[y] ? x : y;
}

int ask(int x, int y, int t = 0)
{
    if (x == y) return x;
    x = dfn[x], y = dfn[y];
    if (x > y) swap(x, y);
    int k = 31 ^ __builtin_clz(y - x);
    k = cmp(spt[k][x + (1 << k)], spt[k][y]);
    return t ? k : par[k];
}

array <int, 4> operator+(const array <int, 4> &a, const array <int, 4> &b)
{
    if (a[0] == b[0]) return {a[0], a[1], b[0], b[1]};
    if (a[0] < b[0]) return a[2] <= b[0] ? a : array {a[0], a[1], b[0], b[1]};
    return b[2] <= a[0] ? b : array {b[0], b[1], a[0], a[1]};
}

void build(vector <int> &ver, vector <int> tre[])
{
    sort(ver.begin(), ver.end(), [] (int x, int y) {return dfn[x] < dfn[y];});
    ver.erase(unique(ver.begin(), ver.end()), ver.end());
    for (int i = 1, t = ver.size(); i < t; i ++) ver.push_back(ask(ver[i - 1], ver[i]));
    sort(ver.begin(), ver.end(), [] (int x, int y) {return dfn[x] < dfn[y];});
    ver.erase(unique(ver.begin(), ver.end()), ver.end());
    for (int i = 1; i < ver.size(); i ++) tre[ask(ver[i - 1], ver[i])].push_back(ver[i]);
}

void mdf(int k, int v)
{
    k = tin[k];
    while (k <= tin[0]) bit[k] += v, k += -k & k;
}

int qry(int k)
{
    int l = tin[k] - 1, r = tou[k], v = 0;
    while (l >= 1) v -= bit[l], l -= -l & l;
    while (r >= 1) v += bit[r], r -= -r & r;
    return v;
}

void dfsR(int u)
{
    tin[u] = ++tin[0];
    for (int i = 1; i < L; i ++) anc[u][i] = anc[anc[u][i - 1]][i - 1];
    for (int v : trr[u]) anc[v][0] = u, dfsR(v);
    tou[u] = tin[0];
}

void dfsL(int u)
{
    siz[u] = 1, son[u] = 0;
    for (int v : trl[u])
    {
        dfsL(v), siz[u] += siz[v];
        if (siz[v] > siz[son[u]]) son[u] = v;
    }
}

int get(int u)
{
    if (qry(u)) return u;
    for (int i = L - 1; ~i; i --)
        if (anc[u][i] && !qry(anc[u][i]))
            u = anc[u][i];
    return anc[u][0];
}

void dfs3(int u, int t)
{
    for (int w : trl[u])
        if (w != son[u]) dfs3(w, 1);
    vector <int> cur; cur.swap(ano[u]);
    if (son[u]) dfs3(son[u], 0), ano[u].swap(ano[son[u]]);
    for (int y : cur)
    {
        int v = get(y);
        if (v && dep[u] + dep[v] > tmp) tmp = dep[u] + dep[v], tmpa = u, tmpb = v;
        mdf(y, 1), ano[u].push_back(y);
    }
    for (int w : trl[u])
        if (w != son[u])
            for (int y : ano[w])
            {
                int v = get(y);
                if (v && dep[u] + dep[v] > tmp) tmp = dep[u] + dep[v], tmpa = u, tmpb = v;
                mdf(y, 1), ano[u].push_back(y);
            }
    if (!t) return;
    for (int y : ano[u]) mdf(y, -1);
}

void dfs2(int u)
{
    for (int i = deg[u]; i < deg[u + 1]; i ++)
    {
        int v = gra[i]; dfs2(v);
        tag[u] = tag[u] + tag[v];
    }
    if (tag[u][3] && dep[u] - tag[u][2] > res) res = dep[u] - tag[u][2], resa = tag[u][1], resb = tag[u][3];

    vector <array <int, 3>> cur;
    for (int i : idx[u])
    {
        auto &[x, y] = pat[i];
        int p = ask(x, u, 1), q = ask(y, u, 1);
        if (p > q) swap(p, q), swap(x, y);
        cur.push_back({p, q, i});
    }
    sort(cur.begin(), cur.end());
    for (int l = 0, r = 0; l < cur.size(); l = r)
    {
        while (r < cur.size() && cur[r][0] == cur[l][0] && cur[r][1] == cur[l][1]) ++r;
        vector <int> lft, rgt;
        lft.reserve(r - l + 1), rgt.reserve(r - l + 1);
        lft.push_back(u), rgt.push_back(u);
        for (int i = l; i < r; i ++)
        {
            auto [x, y] = pat[cur[i][2]];
            lft.push_back(x), rgt.push_back(y);
            ano[x].push_back(y);
        }
        build(lft, trl), build(rgt, trr);
        tin[0] = tmp = 0, memset(anc[u], 0, sizeof(anc[u]));
        dfsR(u), dfsL(u), dfs3(u, 1), tmp -= dep[u] * 2;
        if (tmp > ans) ans = tmp, ansa = tmpa, ansb = tmpb;
        for (int i : lft) trl[i].clear(), ano[i].clear();
        for (int i : rgt) trr[i].clear();
    }
}

int main()
{
    n = R(), m = R();
    for (int i = 2; i <= n; i ++) ++deg[par[i] = R()];
    for (int i = 1; i <= n; i ++) deg[i + 1] += deg[i];
    for (int i = 2; i <= n; i ++) gra[--deg[par[i]]] = i;
    for (int i = 1; i <= m; i ++) pat[i][0] = R(), pat[i][1] = R();

    dfs(1);
    for (int i = 1; 1 << i <= n; i ++)
        for (int j = 1 << i; j <= n; j ++)
            spt[i][j] = cmp(spt[i - 1][j - (1 << (i - 1))], spt[i - 1][j]);

    for (int i = 1; i <= n; i ++) tag[i] = {n + 1, 0, n + 1, 0};
    for (int i = 1; i <= m; i ++)
    {
        auto [u, v] = pat[i]; int w = ask(u, v);
        tag[u] = tag[u] + array {dep[w], i, n + 1, 0};
        tag[v] = tag[v] + array {dep[w], i, n + 1, 0};
        idx[w].push_back(i);
    }
    resa = 1, resb = 2;
    dfs2(1);
    if (res < ans)
    {
        res = ans;
        for (int i = 1; i <= m; i ++)
            if (ask(pat[i][0], ansa) == ansa && ask(pat[i][1], ansb) == ansb)
                resb = resa, resa = i;
    }
    cout << res << endl << resa << ' ' << resb << endl;

    IO::flush();
    return 0;
}

::::