Min-Max 容斥

· · 算法·理论

简介:

用最小值推最大值的一种容斥。

基本式子:

Max(S) 表示一个集合内的最大值,Min(S) 为一个集合内的最小值,则有:

Max(S)=\sum\limits_{T\subseteq S}(-1)^{|T|+1}Min(T)

证明如下:
考虑将 S 内的元素降序排序得到序列 a_{1...n},则有:

\begin{aligned}&\sum\limits_{T\subseteq S}(-1)^{|T|+1}Min(T)\\&=\sum\limits_{i=1}^{n}a_i(\sum\limits_{2|j}\dbinom{i-1}{j}-\sum\limits_{2\nmid j}\dbinom{i-1}{j})\\&=\sum\limits_{i=1}^{n}a_i[i-1=0]\\&=a_1\\&=Max(S)\end{aligned}

证毕。

期望意义下拓展:

\begin{aligned}&E(Max(S))\\&=E(\sum\limits_{T\subseteq S}(-1)^{|T|+1}Min(T))\\&=\sum\limits_{T\subseteq S}(-1)^{|T|+1}E(Min(T))\end{aligned}

例题:

P5643
设节点 i 首次到达的时间为 t_i,则要求 E(\max\limits_{i\in S}t_i),直接上 Min-Max 容斥。
那么我们要求 f_S=E(\min\limits_{i\in S}t_i)。可以枚举 S,每次做一遍 DP:设 dp_i 表示目前在 i 节点,之后首次走到某个 S 内节点的期望用时。则 f_S=dp_{root}
不难写出转移式:

dp_i=\begin{cases}\sum\limits_{(i,j)\in E}\frac{1}{d_i}dp_j+1&i\notin S\\0&i\in S\end{cases}

其中 d_i 表示节点度数。
高斯消元!......吗?我们发现转移的依赖关系和原图是一样的,即一棵树,因此我们可以从下到上消元。
具体地,我们可以求出 k_i,b_i 满足 dp_i=k_idp_{fa_i}+b_i。我们可以在树上深搜一遍,当做完 i 的所有子树的时候我们可以得到 i 的所有儿子 ji 的关系,那么直接带入消元,然后我们就又得到了 ifa_i 之间的关系。这样的时间复杂度是 O(n) 的。一些细节见代码。
最后我们还要求 ans_S=\sum\limits_{T\subseteq}S(-1)^{|T|+1}f_T
我们可以先将 f_S 都乘上自己的容斥系数,然后做一遍高维前缀和,详见代码。
总时间复杂度 O(n2^n+kQ)

代码:

#include <bits/stdc++.h>
using namespace std;
const int N = 20,M = (1 << 18) + 5,mod = 998244353;
int n,q,r,f[M],pc[M],k[N],b[N];
vector <int> v[N];
int ksm(int x,int y)
{
    int ret = 1;
    while(y)
    {
        if(y & 1) ret = 1LL * ret * x % mod;
        x = 1LL * x * x % mod,y >>= 1;
    }
    return ret;
}
void dfs(int x,int fa,int s)
{
    k[x] = b[x] = 1;
    int p = ksm(v[x].size(),mod - 2),inv;
    for(int y : v[x])
    {
        if(y == fa) continue;
        dfs(y,x,s);
        k[x] = (k[x] - 1LL * p * k[y] % mod + mod) % mod;
        b[x] = (b[x] + 1LL * p * b[y] % mod) % mod;
    }
    if(s & (1 << (x - 1))) k[x] = b[x] = 0;
    else
    {
        inv = ksm(k[x],mod - 2);
        if(fa) k[x] = 1LL * p * inv % mod;
        else k[x] = 0;
        b[x] = 1LL * b[x] * inv % mod;
    }
}
int main()
{
    scanf("%d%d%d",&n,&q,&r);
    for(int i = 1,a,b;i < n;i++)
    {
        scanf("%d%d",&a,&b);
        v[a].push_back(b);
        v[b].push_back(a);
    }
    for(int i = 1;i < (1 << n);i++)
    {
        pc[i] = (pc[i >> 1] ^ (i & 1));
        dfs(r,0,i);
        f[i] = (pc[i] ? b[r] : mod - b[r]);
    }
    for(int i = 0;i < n;i++)
        for(int j = 1;j < (1 << n);j++)
            if(j & (1 << i)) f[j] = (f[j] + f[j ^ (1 << i)]) % mod;
    for(int i = 1,s,k,p;i <= q;i++)
    {
        s = 0;
        scanf("%d",&k);
        while(k--)
        {
            scanf("%d",&p);
            s |= (1 << (p - 1));
        }
        printf("%d\n",f[s]);
    }
    return 0;
}