题解 P4180 【【模板】严格次小生成树[BJWC2010]】

· · 题解

蒟蒻的广告:http://www.neptuuz.com/wordpress/?p=296

思路上面的 dalao 们都已经讲得非常清楚了,不过好像还没有人用树剖做?那这里就发一下树剖的题解好了。

线段树维护区间最大值 m1 和严格次大值 m2。合并两个区间时,最大值直接取两边 m1 中较大的一个即可;严格次大值则要把 l.m1、l.m2、r.m1、r.m2 四个值一起比较,取其中严格小于最大值的最大者。要注意必须先更新 m2,再更新 m1(因为计算 m2 时要用到更新前的 m1)。

不过会被卡常,开个 O2 就稳过了……Code:

#include<stdio.h>
#include<algorithm>
#include<iostream>
#include<string.h>
#define MAXN 100010
#define MAXM 300010
// Index into t[] for the segment-tree node covering [l, r]: a leaf (l == r)
// gets the even value l + r, any wider segment gets l + r with bit 0 set.
// Every argument is parenthesised so that expressions containing operators of
// lower precedence than '+' or '!=' still expand as intended.
// Note: each argument is evaluated twice, so pass side-effect-free expressions.
#define pos(l, r) (((l) + (r)) | ((l) != (r)))
using namespace std;

// One edge record. Records are ordered by weight so that sorting the array
// yields the processing order used by Kruskal in main().
struct Edge {
    int u;     // endpoint the edge is stored under
    int v;     // other endpoint
    int w;     // weight
    int f;     // set to 1 by main() for edges chosen into the spanning tree
    int next;  // index of the next edge in u's adjacency list, 0 = end of list

    bool operator < (const Edge &A) const { return A.w > w; }
} e[MAXM*2];

// Segment-tree node / query result: the largest value of a range and the
// largest value strictly below it (0 when modify() creates a leaf).
struct node {
    int m1;  // maximum
    int m2;  // strict second maximum
} t[MAXN*2];

int n, m;          // vertex count and edge count read by main()
int h[MAXN];       // h[u]: index of the first edge in u's adjacency list (0 = none)
int dep[MAXN];     // dep[v]: depth assigned by dfs() (parent's depth + 1)
int son[MAXN];     // son[u]: child with the largest subtree size, chosen by dfs()
int w[MAXN];       // w[u]: subtree size from dfs(); also serves as its visited mark
int tot;           // number of edge records currently stored in e[]
int fa[MAXN];      // union-find parent used by find()
int par[MAXN];     // par[v]: tree parent assigned by dfs()
int cnt;           // counter handing out id[] values in init()
int id[MAXN];      // id[u]: position of u in the segment tree order
int top[MAXN];     // top[u]: head of the chain containing u; 0 = not yet visited by init()
int a[MAXN];       // a[v]: weight of the tree edge between v and par[v]
int TOT;           // value of tot right after the input edges were read
long long MST;               // total weight of the spanning tree built in main()
long long secMST = 1e15;     // best strictly larger spanning-tree weight found so far
// Returns the larger of two ints (b when they are equal).
inline int max(int a, int b) {
    if (a > b) {
        return a;
    }
    return b;
}
// Returns the strict second maximum of the four arguments, i.e. the largest
// value that is strictly smaller than their maximum.
//
// When all four values are equal there is no such value; 0 is returned in
// that case, the same "none" marker modify() writes into a leaf's m2. (The
// previous version ran off the end of the function without returning here,
// which is undefined behaviour.)
//
// Implemented with two linear scans instead of sorting, since this is called
// on every segment-tree merge.
inline int secmax(int a, int b, int c, int d) {
    const int vals[4] = {a, b, c, d};

    int hi = vals[0];
    for (int i = 1; i < 4; ++i) {
        if (vals[i] > hi) hi = vals[i];
    }

    int second = 0;   // result when no value lies strictly below hi
    int found = 0;
    for (int i = 0; i < 4; ++i) {
        if (vals[i] < hi && (!found || vals[i] > second)) {
            second = vals[i];
            found = 1;
        }
    }
    return second;
}

void addEdge(int ui, int vi, int wi, int fi) {
    e[++tot] = (Edge) {ui, vi, wi, fi, h[ui]};
    h[ui] = tot;
}

// Union-find lookup with path compression: returns the representative of x
// and re-points every node on the walked path directly at it.
int find(int x) {
    // First pass: locate the representative.
    int root = x;
    while (fa[root] != root) {
        root = fa[root];
    }
    // Second pass: compress the path just walked.
    while (fa[x] != root) {
        int up = fa[x];
        fa[x] = root;
        x = up;
    }
    return root;
}

// First decomposition pass over the tree stored in the adjacency lists.
// For every vertex reached from u it fills in depth, parent, the weight of the
// edge to the parent (a[]), the subtree size (w[]) and the heavy child (son[]).
// A vertex counts as visited once its w[] entry is non-zero.
void dfs(int u) {
    w[u] = 1;
    for (int i = h[u]; i != 0; i = e[i].next) {
        const int child = e[i].v;
        if (w[child]) continue;          // already visited (this is the parent)

        dep[child] = dep[u] + 1;
        par[child] = u;
        a[child] = e[i].w;
        dfs(child);

        w[u] += w[child];
        if (w[child] > w[son[u]]) son[u] = child;
    }
}

// Second decomposition pass: numbers the vertices (id[]) and records the head
// of each chain (top[]). The heavy child is visited first and stays on u's
// chain; every other unvisited neighbour starts a new chain headed by itself.
// A vertex counts as visited once its top[] entry is non-zero.
void init(int u, int p) {
    top[u] = p;
    id[u] = ++cnt;

    if (son[u] != 0) {
        init(son[u], p);
    }
    for (int i = h[u]; i != 0; i = e[i].next) {
        const int v = e[i].v;
        if (top[v] == 0) {
            init(v, v);
        }
    }
}

// Point update: stores d at position x of the segment tree.
//   l, r : range covered by the current node
//   p    : index of the current node in t[]
// A leaf gets m1 = d and m2 = 0; every ancestor is recomputed from its two
// children on the way back up.
void modify(int l, int r, int x, int d, int p) {
    if (l == r) {
        t[p].m1 = d;
        t[p].m2 = 0;
        return;
    }

    const int mid = (l+r) >> 1;
    const int left = pos(l, mid);
    const int right = pos(mid+1, r);

    if (x > mid) {
        modify(mid+1, r, x, d, right);
    } else {
        modify(l, mid, x, d, left);
    }

    // Second maximum first, then the maximum.
    t[p].m2 = secmax(t[left].m1, t[left].m2, t[right].m1, t[right].m2);
    t[p].m1 = max(t[left].m1, t[right].m1);
}

// Range query: returns the maximum and strict second maximum over positions
// [x, y], restricted to the current node's range [l, r] (node index p).
node query(int l, int r, int x, int y, int p) {
    if (x <= l && r <= y) {
        return t[p];                     // node lies completely inside [x, y]
    }

    const int mid = (l+r) >> 1;
    const int left = pos(l, mid);
    const int right = pos(mid+1, r);

    const bool touchesLeft = (x <= mid);
    const bool touchesRight = (y > mid);

    if (touchesLeft && touchesRight) {
        node lo = query(l, mid, x, y, left);
        node hi = query(mid+1, r, x, y, right);
        node merged;
        merged.m1 = max(lo.m1, hi.m1);
        merged.m2 = secmax(lo.m1, lo.m2, hi.m1, hi.m2);
        return merged;
    }
    if (touchesLeft) {
        return query(l, mid, x, y, left);
    }
    return query(mid+1, r, x, y, right);
}

// Returns the maximum and strict second maximum of the a[] values along the
// tree path between u and v. The result is {0, 0} when u == v.
//
// While the two vertices are on different chains, the one whose chain head is
// deeper is lifted: its chain segment is queried and it jumps to the parent of
// that head. Once both are on one chain, the remaining segment is queried
// starting just after the shallower vertex, so that vertex's own a[] value is
// left out.
node solve(int u, int v) {
    node acc;
    acc.m1 = 0;
    acc.m2 = 0;

    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) {
            swap(u, v);
        }
        const int head = top[u];
        node seg = query(1, n, id[head], id[u], pos(1, n));
        // m2 must be merged before m1 because it reads the old acc.m1.
        acc.m2 = secmax(acc.m1, acc.m2, seg.m1, seg.m2);
        acc.m1 = max(acc.m1, seg.m1);
        u = par[head];
    }

    if (u != v) {
        if (dep[u] < dep[v]) {
            swap(u, v);
        }
        node seg = query(1, n, id[v]+1, id[u], pos(1, n));
        acc.m2 = secmax(acc.m1, acc.m2, seg.m1, seg.m2);
        acc.m1 = max(acc.m1, seg.m1);
    }
    return acc;
}

// Reads "n m" followed by m lines "u v w", builds a minimum spanning tree with
// Kruskal, and prints the smallest spanning-tree weight that is strictly
// greater than the MST weight (1e15 if no candidate is found).
//
// For every non-tree edge (u, v, w) the candidate is MST + w - x, where x is
// the largest tree-path weight between u and v that is strictly smaller than w.
//
// Fixes over the previous version:
//  * the Kruskal edge counter k is now initialised (it was read while
//    indeterminate, which is undefined behaviour);
//  * the value 0 that solve()/modify() use for "no such edge" was subtracted
//    as if it were a real weight, producing bogus candidates MST + w for
//    self-loops and for cycles whose tree edges all weigh exactly w. Weights
//    are now stored shifted by +1, so 0 can only mean "none", and such edges
//    are skipped;
//  * malformed input is detected instead of using unread variables.
//
// Assumes edge weights are non-negative and smaller than INT_MAX so the +1
// shift cannot overflow.
int main() {
    if (scanf("%d%d", &n, &m) != 2) return 1;
    for (int i = 1; i <= m; ++i) {
        int ui, vi, wi;
        if (scanf("%d%d%d", &ui, &vi, &wi) != 3) return 1;
        addEdge(ui, vi, wi, 0);
    }

    // Kruskal: sort by weight, then rebuild the adjacency lists so that they
    // contain only tree edges (both directions).
    sort(e+1, e+tot+1);
    TOT = tot;
    memset(h, 0, sizeof(h));
    for (int i = 1; i <= n; ++i) fa[i] = i;
    for (int i = 1, k = 0; i <= TOT && k < n-1; ++i) {
        int ux = find(e[i].u), uy = find(e[i].v);
        if (ux == uy) continue;
        fa[ux] = uy;
        MST += e[i].w;
        e[i].f = 1;
        e[i].next = h[e[i].u];
        h[e[i].u] = i;
        addEdge(e[i].v, e[i].u, e[i].w, 1);   // reverse direction, stored after the input edges
        ++k;
    }

    dfs(1);
    init(1, 1);

    // Store weight + 1 so every real edge is >= 1 and 0 is reserved for
    // "none". The root's slot holds a dummy value; solve() never queries it.
    for (int i = 1; i <= n; ++i) modify(1, n, id[i], a[i]+1, pos(1, n));

    for (int i = 1; i <= TOT; ++i) {
        if (e[i].f) continue;                  // tree edge
        node cross = solve(e[i].u, e[i].v);
        int key = e[i].w + 1;                  // same +1 shift as the stored weights
        int replaced = (cross.m1 == key) ? cross.m2 : cross.m1;
        if (replaced == 0) continue;           // self-loop, or nothing strictly lighter on the path
        long long tmp = MST + key - replaced;  // the two shifts cancel out
        if (tmp > MST && tmp < secMST) secMST = tmp;
    }
    printf("%lld\n", secMST);
    return 0;
}