P3714 [BJOI2017] 树的难题

· · 题解

本题存在单 \log 做法。

我们考虑点分治,选择重心 x 作为根。

假设每条边有一个权值,权值定义为每条边的权值和,我们的做法是依次遍历所有儿子的子树,计算出子树内 g(d) 表示深度为 d 的最大权值和是多少,同时维护前面的 f(d) 表示前面的子树中深度为 d 的最大是多少。

计算答案可以从大到小扫一遍 g,然后将查询 [L-i,R-i]f 的最大值,可以用单调队列维护。

注意到每次统计答案的时间复杂度和 g,f 的长度相关,我们可以将所有儿子按照子树深度从小到大排序,然后从左往右扫描,设 d_i 表示排序后第 i 子树的深度,则我们的时间复杂度是 O(\sum(d_i+d_{i+1})),而由于 d_i \le d_{i+1},所以实际上就是 O(\sum d_i),而 \sum d_i 的上界是当前分治的节点个数,所以这样扫描使得每一轮的时间复杂度是 O(n),点分治就是 O(n \log n)

注意到还需要将子树按照深度排序,这样看似是 O(n \log^2 n) 的,实际上这是 O(n \log n)

我们每次需要排序的长度是子树的个数,而每个子树都会继续分治。

而我们知道分治的总次数是 O(n) 的,所以实际上这些排序的长度和也是 O(n) 的,所以总时间复杂度是 O(n \log n) 的。

好了,现在就把点分治部分讲完了,回到原问题。

我们不用更改上面的框架,我们思考如何处理这个颜色段。

首先,我们还是可以算出从 x 到每个点的路径的权值。

但现在问题在于两个路径的合并有两种:从 x 出发的两条路径,如果第一条边颜色相同,还要减去这个颜色的权值一次。

我们考虑现在有一个从 x 出发的路径 A,他希望前面找一条路径合并使得权值尽量大,那么只能是以下两种:

这是后可以不用线段树等数据结构,这种问题有一个很简单的方法:记录最大值和次大值。

我们考虑记录前面所有路径的最大值和次大值,要求最大值和次大值的颜色不能相同。

那么如果 A 和最大值不同色,那么它肯定选择和最大值合并,这是显然的。

如果 A 和最大值同色,那么它一定和次大值不同色,那么最大值就是权值最大的同色路径,次大值就是权值最大的异色路径,我们比较以下而这即可。

所以我们用 f 数组来记录最大值和次大值即可,剩下部分没有区别,只是记录的信息变了一下。

这样就能做到严格 O(n \log n) 了。这个记录最大值和次大值的技巧实际上和树形 dp 求直径有点像。

代码实现和重建计划,Freezing with Style这两个问题几乎一样。

#include <iostream>
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 2e5 + 5;
const int inf = -2e9;
typedef pair<int, int> pii;
typedef pair<pii, pii> pp;
#define debug(x) cout << #x << "=" << x << endl
#define mp(x, y) make_pair(x, y)
#define fi first
#define se second
#define all(x) x.begin(), x.end()

int n, m, L, R;
struct Edge {
    int to, val;
    Edge (int _to = 0, int _val = 0) :
        to(_to), val(_val) {}
};
vector<Edge> e[N];
int c[N] = {0};
bool vis[N] = {false};
int sz[N] = {0};
int getsz(int x, int pr) {
    sz[x] = 1;
    for (auto i: e[x])
        if (i.to != pr && !vis[i.to])
            sz[x] += getsz(i.to, x);
    return sz[x];
}
int totsz, mxsz, rt;
void getrt(int x, int pr) {
    int msz = totsz - sz[x];
    for (auto i: e[x])
        if (i.to != pr && !vis[i.to]) {
            getrt(i.to, x);
            msz = max(msz, sz[i.to]);
        }
    if (mxsz > msz)
        mxsz = msz, rt = x;
}
int getd(int x, int pr) {
    int mxd = 0;
    for (auto i: e[x])
        if (i.to != pr && !vis[i.to]) 
            mxd = max(mxd, getd(i.to, x) + 1);
    return mxd;
}

pp f[N];
int g[N] = {0};
void add(pp &x, pii y) {
    if (y.fi > x.fi.fi) {
        if (x.fi.se != y.se)
            x.se = x.fi, x.fi = y;
        else
            x.fi = y;
    }
    else if (x.se.fi < y.fi && x.fi.se != y.se) 
        x.se = y;
}

void upd(int x, int pr, int d, int tv, int cl, int cfr) {
    g[d] = max(g[d], tv);
    for (auto i: e[x])
        if (i.to != pr && !vis[i.to])
            upd(i.to, x, d + 1, tv + (cfr != i.val) * c[i.val], cl, i.val);
}

int ans = 0;

int q[N] = {0};
int cal(pp x, int cl) {
    if (x.fi.se != cl)
        return x.fi.fi;
    return max(x.fi.fi - c[cl], x.se.fi);
}
void calc(int _n, int _m, int cl) {
/*  debug(_n);
    for (int i = 0; i <= _n; i++)
        printf("[(%d, %d), (%d, %d)], ", f[i].fi.fi, f[i].fi.se, f[i].se.fi, f[i].se.se);
    printf("\n");
    debug(_m);
    for (int i = 1; i <= _m; i++)
        printf("%d, ", g[i]);
    printf("\n");*/
    int l = 0, r = 0;
    for (int i = _m, j = 0; i >= 1; i--) {
        while (j <= _n && j <= R - i) {
            while (l < r && cal(f[q[r - 1]], cl) < cal(f[j], cl))
                r--;
            q[r++] = j++;
        }
        while (l < r && q[l] < L - i)
            l++;
        if (l < r)
            ans = max(ans, g[i] + cal(f[q[l]], cl));
    }
}

void slv(int x) {
    getsz(x, 0);
    totsz = sz[x], mxsz = 2e9, rt = 0;
    getrt(x, 0);
    x = rt;
    vector<pair<int, pii> > res;
    for (auto i: e[x])
        if (!vis[i.to])
            res.push_back(mp(getd(i.to, x) + 1, mp(i.to, i.val)));
    sort(all(res));
    int len = 0;
    f[0] = mp(mp(0, 0), mp(0, 0));

//  debug(x);

    for (auto j: res) {
        int mxd = j.fi, u = j.se.fi, w = j.se.se;
        for (int i = 1; i <= mxd; i++)
            g[i] = inf;
    //  debug(u);
        upd(u, x, 1, c[w], w, w);
        calc(len, mxd, w);
        for (int i = 1; i <= mxd; i++) {
            if (i <= len)
                add(f[i], mp(g[i], w));
            else
                f[i] = mp(mp(g[i], w), mp(inf, 0));
        }
        len = mxd;
    }

    vis[x] = true;
    for (auto i: res)
        slv(i.se.fi);
}

int main() {
    scanf("%d%d%d%d", &n, &m, &L, &R);
    for (int i = 1; i <= m; i++)
        scanf("%d", &c[i]);
    for (int i = 1, u, v, w; i < n; i++) {
        scanf("%d%d%d", &u, &v, &w);
        e[u].push_back(Edge(v, w));
        e[v].push_back(Edge(u, w));
    }
    ans = -2e9;
    slv(1);
    printf("%d\n", ans);
    return 0;
}