P2664 线段树合并题解

· · 题解

先考虑一个点起始怎么做,再换根。显然可以 DFS,但是不够数据结构化,不好换根。考虑每种颜色统计有多少点使 $s$ 到该点经过该种颜色。 使用线段树合并。点 $u$ 线段树上 $T_i=j$ 表示 $u$ 子树上所有点中,有 $j$ 个点到 $u$ 的路径中经过颜色 $i$。此时点 $u$ 在其子树内的答案为 $\sum T_i$。合并时有($sz_u$ 表示以 $u$ 为根的子树的大小): $$ {T_u}_i=\begin{cases} \sum {T_v}_i(v 是 u 的儿子) && i\neq c_u\\ sz_u && i = c_u \end{cases} $$ 即如图: ![](https://cdn.luogu.com.cn/upload/image_hosting/8hfsqwbt.png) $u$ 开头的路径中,其余颜色出现的次数不变,而 $u$ 到子树内任一点都要经过 $c_u$。 现在求出了以某一点为根的答案,考虑换根。考察如图的情况: ![](https://cdn.luogu.com.cn/upload/image_hosting/xn0qqvb2.png) 已知 $f$ 为根时的答案,从 $f$ 转移到 $u$ 的过程中,有: $T^\prime_i=\begin{cases}T_i && i \neq c_u 且i\neq c_f\\ n && i =c_u\\ n - sz_u + {{T_u}_c}_f && i=c_f \end{cases}

第三种情况中,T_u 为线段树合并得到的 u 的线段树。此时意为,f 以外的部分,路径必经过 c_f,而以内的部分已经求过,线段树合并时将这个值记录即可。

线段树合并 \mathcal O(n \log n),而每次转移是 \mathcal O(\log n) 的,总复杂度 \mathcal O(n \log n)

#include<bits/stdc++.h>
#define debug(x) cerr << #x << ": " << x << '\n';
#define rep(i, a, b) for (int i = (a); i <= (b); i++)
#define lop(i, a, b) for (int i = (a); i < (b); i++)
#define dwn(i, a, b) for (int i = (a); i >= (b); i--)
#define elif else if
#define iosfst ios::sync_with_stdio(0);cin.tie(0), cout.tie(0)
#define pb push_back
#define _if (
#define _then ?(
#define _els ):
#define _end )
#define intt long long
using namespace std;
#define N 100005
bool mem1;
vector<int>g[N];
int n, fa[N], a[N], sz[N];intt val[N], ans[N];
namespace sgt{
    int tp, ls[N * 32], rs[N * 32], fa[N * 32];intt tr[N * 32];
    int rt[N * 32];
    inline void pushup(int p) {
        tr[p] = tr[ls[p]] + tr[rs[p]];
    }
    inline void upd(int p, int l, int r, int x, intt v) {
        if(l == r) {
            tr[p] += v;
            return ;
        }
        int mid = (l + r) >> 1;
        if(x <= mid) {
            if(!ls[p]) {
                ls[p] = ++ tp;
                fa[tp] = p;
            }
            upd(ls[p], l, mid, x, v);
        }
        else {
            if(!rs[p]) {
                rs[p] = ++ tp;
                fa[tp] = p;
            }
            upd(rs[p], mid + 1, r, x, v);
        }
        pushup(p);
    }
    inline void set(int p, int l, int r, intt x, intt v) {
        if(l == r) {
            tr[p] = v;
            return ;
        }
        int mid = (l + r) >> 1;
        if(x <= mid) {
            if(!ls[p]) {
                ls[p] = ++ tp;
                fa[tp] = p;
            }
            set(ls[p], l, mid, x, v);
        }
        else {
            if(!rs[p]) {
                rs[p] = ++ tp;
                fa[tp] = p;
            }
            set(rs[p], mid + 1, r, x, v);
        }
        pushup(p);
        //cout << "Set " << x << " " << v << "\n";
        //cout << p << " " << l << " " << r << " " << tr[p] << "\n";
    }
    inline int merge(int p1, int p2) {
        if(!p1 || !p2) return p1 | p2;
        ls[p1] = merge(ls[p1], ls[p2]);
        rs[p1] = merge(rs[p1], rs[p2]);
        tr[p1] += tr[p2];
        return p1;
    }
    inline int qry(int p, int l, int r, intt x) {
        if(l == r) 
            return tr[p];
        int mid = (l + r) >> 1;
        if(x <= mid) {
            if(!ls[p]) 
                return 0;
            return qry(ls[p], l, mid, x);
        }
        else {
            if(!rs[p]) 
                return 0;
            return qry(rs[p], mid + 1, r, x);
        }
    }
}
bool mem2;

void dfs(int u, int f) {
    fa[u] = f;
    sz[u] = 1;
    for(auto v : g[u])
        if(v != f)
            dfs(v, u),
            sgt::rt[u] = sgt::merge(sgt::rt[u], sgt::rt[v]),
            sz[u] += sz[v];
    if(sgt::rt[u] == 0)
        sgt::rt[u] = ++sgt::tp;
    sgt::set(sgt::rt[u], 1, 1000000, a[u], sz[u]);  
    if(f)
        val[u] = sgt::qry(sgt::rt[u], 1, 1000000, a[f]);  
}

void dfs2(int u, int f) {
    //cout << "DFS2 " << u << " " << f << "\n";
    if(u == 1) {
        for(auto v : g[u])
            dfs2(v, u);
        return ;
    }
    intt rbp = sgt::qry(sgt::rt[1], 1, 1000000, a[u]);
    intt rdx = sgt::qry(sgt::rt[1], 1, 1000000, a[f]);
    sgt::set(sgt::rt[1], 1, 1000000, a[u], n);
    sgt::set(sgt::rt[1], 1, 1000000, a[f], n - sz[u] + val[u]);
    ans[u] = sgt::tr[sgt::rt[1]];
    for(auto v: g[u])
        if(f != v)
            dfs2(v, u);
    sgt::set(sgt::rt[1], 1, 1000000, a[u], rbp);
    sgt::set(sgt::rt[1], 1, 1000000, a[f], rdx);
}

signed main() {
    iosfst;
    cin >> n;
    rep(i, 1, n) cin >> a[i];
    lop(i, 1, n) {
        int x, y;
        cin >> x >> y;
        g[x].pb(y);
        g[y].pb(x);
    }
    dfs(1, 0);
    ans[1] = sgt::tr[sgt::rt[1]];
    dfs2(1, 0);
    rep(i, 1, n)
        cout << ans[i] << "\n";
}