题解:P13275 [NOI2025] 集合

· · 题解

退役好久没做题所以也没更新. 最近放假在家摆烂摸鱼, 刷 b 站看到某位主播 vp 今年 NOI, 发现还有这种魔怔数数题, 就来做着玩玩. 训 AtCoder 取模计数题的人会有好报的. 😊

给定正整数 n 和序列 a_0,\ldots,a_{2^n-1}\in\mathbb F_p, 对 S\subseteq [0,2^n)\cap\mathbb Z 定义 \mathrm{And}(S)S 中所有数的按位与, 要求:

\sum_{S,T\subseteq[0,2^n)}\mathbf 1\{S\cap T=\varnothing\}\cdot\mathbf 1\{\mathrm{And}(S)=\mathrm{And}(T)\}\prod_{i\in S\cup T}a_i.

数据范围: n\le 20p=998\,244\,353, 其中部分测试点满足 a_i\ne -1.

考虑对 \mathbf 1\{\mathrm{And}(S)=\mathrm{And}(T)\} 容斥: 枚举 x,y\in[0,2^n), 设 S_x 表示包含 x 的数的集合, 条件 x\subseteq\mathrm{And}(S)y\subseteq\mathrm{And}(T) 等价于 S\subseteq S_xT\subseteq S_y, 又因为

\mathbf 1\{\mathrm{And}(S)=\mathrm{And}(T)\}=\sum_{x\subseteq\mathrm{And}(S)}\sum_{y\subseteq\mathrm{And}(T)}(-1)^{\mathrm{pc}(x)+\mathrm{pc}(y)}\cdot 2^{\mathrm{pc}(x\&y)},

代回原式即得容斥系数是 (-1)^{\mathrm{pc}(x)+\mathrm{pc}(y)}\cdot 2^{\mathrm{pc}(x\&y)}, 贡献是 \sum_{S\subseteq S_x}\sum_{T\subseteq S_y}\mathbf 1\{S\cap T=\varnothing\}\prod_{i\in S\cup T}a_i.

这里 \mathrm{pc}(x) 表示 x 的二进制表示中 1 的数量. 容斥系数怎么算的? 将条件表示为每个二进制位的 "相等" 条件相乘, 权值 ((1,0),(0,1)) 作差分得到 ((1,-1),(-1,2)), 按分配律全部展开就好啦.

贡献式里每个元素的方案独立: S_x\cap S_y=S_{x|y} 的元素可以放入 ST 或不放, 方案数为 1+2a_i; 其他 (S_x\cup S_y)\setminus S_{x|y} 的元素则是 1+a_i.

又因为 \mathrm{pc}(x|y)=\mathrm{pc}(x)+\mathrm{pc}(y)-\mathrm{pc}(x\&y), 我们将其化为 OR 卷积的形式: 设 s_x:=\prod_{i\supseteq x}(1+a_i), t_x:=\prod_{i\supseteq x}\frac{1+2a_i}{(1+a_i)^2}, 答案是 (-2)^{\mathrm{pc}(x)}s_x 卷积自己, 逐项乘上 2^{-\mathrm{pc}(x)}t_x 再求和.

如果有 a_i=-1 咋办呢? 我们要把 s_xs_y 中属于 t_{x|y} 的 0 因子剔除. 考虑 FMT 的过程, 什么情况下 s_xs_y 的卷积项会贡献到 z 的位置? 只有 z 包含 xy 的情况. 此时 t_z 的分母必定整除 s_xs_y, 即求和的每一项都被 t_z 的分母整除. 所以只需要维护 0 次数最低的一项, 就算有 a+(-a) 相消也不用关心更高次项的系数, 肯定是用不到的.

随便看了一下其他题解, 感觉这部分讲得比较模糊, 所以在这里也传了一份.

总结一下, 用后缀和/积计算 s_xt_x 的 0 次数和系数, 再跑一个 OR 卷积就完成了, 时间复杂度 O(n2^n).

#include <bits/stdc++.h>
#define fi first
#define se second
#define pc __builtin_popcount
typedef std::pair<int, int> PII; // coefficient and degree of 0
typedef long long LL;

const int mod = 998244353;
int qmo(int x) { return x + (x >> 31 & mod); }
int ksm(int a, int b) {
  int res = 1;
  for (; b; b >>= 1, a = (LL)a * a % mod)
    if (b & 1) res = (LL)res * a % mod;
  return res;
}

PII operator + (const PII &a, const PII &b) {
  if (a.se == b.se) return PII(qmo(a.fi + b.fi - mod), a.se);
  return a.se < b.se ? a : b;
}
PII operator += (PII &a, const PII &b) { a = a + b; return a; }

PII operator - (const PII &a) { return PII(qmo(-a.fi), a.se); }
PII operator - (const PII &a, const PII &b) { return a + (-b); }
PII operator -= (PII &a, const PII &b) { a = a - b; return a; }

PII operator * (const PII &a, int b) {
  if (b == 0) return PII(a.fi, a.se + 1);
  return PII((LL)a.fi * b % mod, a.se);
}
PII operator *= (PII &a, int b) { a = a * b; return a; }

PII operator * (const PII &a, const PII &b) {
  return PII((LL)a.fi * b.fi % mod, a.se + b.se);
}
PII operator *= (PII &a, const PII &b) { a = a * b; return a; }

void solve() {
  int n; std::cin >> n;
  std::vector<int> pwn2(n+1), pwi2(n+1); // power of -2 and 1/2
  pwn2[0] = pwi2[0] = 1;
  for (int i = 1; i <= n; ++i) {
    pwn2[i] = pwn2[i-1] * (mod - 2ll) % mod;
    pwi2[i] = (pwi2[i-1] + (pwi2[i-1] & 1) * mod) / 2;
  }

  int N = 1 << n; std::vector<PII> s(N), t(N);
  for (int i = 0, x; i < N; ++i) {
    std::cin >> x;
    if (x == mod - 1) {
      s[i] = PII(1, 1);
      t[i] = PII(mod - 1, -2);
    } else {
      s[i] = PII(1 + x, 0);
      t[i] = PII((1 + 2ll * x) * ksm(1 + x, mod - 3) % mod, 0);
    }
  }

  for (int i = 0; i < n; ++i)
    for (int j = 0; j < N; ++j)
      if (!(j & (1 << i))) {
        s[j] *= s[j ^ (1 << i)];
        t[j] *= t[j ^ (1 << i)];
      }

  for (int i = 0; i < N; ++i) {
    s[i] *= pwn2[pc(i)];
    t[i] *= pwi2[pc(i)];
  }

  for (int i = 0; i < n; ++i)
    for (int j = 0; j < N; ++j)
      if (j & (1 << i)) s[j] += s[j ^ (1 << i)];
  for (int i = 0; i < N; ++i) s[i] *= s[i];
  for (int i = 0; i < n; ++i)
    for (int j = 0; j < N; ++j)
      if (j & (1 << i)) s[j] -= s[j ^ (1 << i)];

  int ans = 0;
  for (int i = 0; i < N; ++i) {
    PII cur = s[i] * t[i]; assert(cur.se >= 0);
    if (cur.se == 0) ans = qmo(ans + cur.fi - mod);
  }
  printf("%d\n", ans);
}

int main() {
  int _, t;
  std::ios::sync_with_stdio(false);
  std::cin >> _ >> t;
  while (t --) { solve(); }
}