CF1467E Distinctive Roots in a Tree

· · 个人记录

CF1467E Distinctive Roots in a Tree

突然开始很好奇,这么简单的题是怎么卡我一个下午的。

直接正着考虑感觉不太好确定,我们不妨考虑不同颜色对答案产生的影响。

因为只需要考虑颜色对答案的影响,本质上钦定任意点为根都是没有问题的。

我们考虑相同颜色的节点 u,v

如果 uv 的祖先,那么可以选的点,只有在 u,v 之间的所有点,以及他们的子树。

如果 u 不是 v 的祖先,那么不能选的点是 u,v 的子树。

可以设 f_i 表示在以 i 为根的子树内,有几种颜色不满足选该子树的条件。

设当前遍历到的节点为 p

如果 p 子树里面的 a_p 个数不是全部的 a_p 其实意味着,肯定有别的 a_p 在外面,那么其子树肯定是不能选的。

别忘了,当前节点的值是 a_p !!

我们考虑在 dfs 的时候,计算子树里面有没有与当前节点相同的颜色,如果有的话,那么当前节点到根都是不行的。

我们不用考虑儿子节点,因为其肯定会被优先标记。

我们计算答案的时候可以使用树上差分。

#include <bits/stdc++.h>
using namespace std;

//#define Fread

#ifdef Fread
char buf[1 << 21], *iS, *iT;
#define gc() (iS == iT ? (iT = (iS = buf) + fread (buf, 1, 1 << 21, stdin), (iS == iT ? EOF : *iS ++)) : *iS ++)
#endif // Fread

template <typename T>
void r1(T &x) {
    x = 0;
    char c(getchar());
    int f(1);
    for(; c < '0' || c > '9'; c = getchar()) if(c == '-') f = -1;
    for(; '0' <= c && c <= '9';c = getchar()) x = (x * 10) + (c ^ 48);
    x *= f;
}

template <typename T,typename... Args> inline void r1(T& t, Args&... args) {
    r1(t);  r1(args...);
}

const int maxn = 2e5 + 5;
const int maxm = maxn << 1;

int n, m;
int head[maxn], cnt;
struct Edge {
    int to, next;
}edg[maxn << 1];
void add(int u,int v) {
    edg[++ cnt] = (Edge) {v, head[u]}, head[u] = cnt;
}

map<int,int> sum, now;
#define rep for(int i = head[p];i;i = edg[i].next)
int a[maxn], f[maxn];
void dfs(int p,int fa) {
    int last = now[a[p]];
    ++ now[a[p]];
    rep {
        int to = edg[i].to;
        if(to == fa) continue;
        int last1 = now[a[p]];
        dfs(to, p);
        if(last1 != now[a[p]]) ++ f[1], -- f[to];
    }
    if(now[a[p]] - last != sum[a[p]]) ++ f[p];
}

int ans;

void dfs1(int p,int fa,int now) {
    if(!now) ++ ans;
    rep {
        int to = edg[i].to;
        if(to == fa) continue;
        dfs1(to, p, now + f[to]);
    }
}
/*
we can set the statue f(i) means all subtrees of i, how many color can't satisfy the rule

we can use difference to solve it
*/
signed main() {
//    freopen(".in", "r", stdin);
//    freopen(".out", "w", stdout);
    int i, j;
    r1(n);
    for(i = 1; i <= n; ++ i) r1(a[i]), sum[a[i]] ++;
    for(i = 1; i < n; ++ i) {
        int u, v; r1(u, v);
        add(u, v), add(v, u);
    }
    dfs(1, 0);
    dfs1(1, 0, f[1]);
    printf("%d\n", ans);
    return 0;
}