[数学记录]AT5143 [AGC035F] Two Histograms

command_block

2021-01-15 08:36:07

Personal

**题意** : 有一个 $n\times m$ 的数组,初始时全 $0$。 对每一行每一列,选出某个前缀将其 $+1$。显然,最终得到的矩阵中只会有 $0,1,2$。 求最终能得到多少本质不同的矩阵。 $n,m\leq 5\times 10^5$ ,时限$\texttt{2s}$。 ------------ 题面相当于给了我们一个映射 : 行列操作前缀长度 $\rightarrow$ 矩阵 我们只能直接知道行列操作的性质,直接统计矩阵较为困难。 分析映射计数无非两条路子,一是找条件,二是找反射。 一般而言,找反射(尤其是双射)较为困难,先从必要条件入手。 先来研究如何构造结果相同的两个不同的操作序列,将操作序列集合化简。 设 $h_0[i]$ 为第 $i$ 行选出的前缀长度,$h_1[i]$ 为列。 不难发现,行列内部无影响,若要相同,只能一行一列之间替换,即 : $h_0[x]=y,h_1[y]=x-1\leftrightarrow h_0[x]=y-1,h_1[y]=x$。(一列伸长,一行缩短) 那不妨将所有能伸缩的位置都让行伸出来,一番调整之后可以在结果相同的前提下做到 $h_0[x]=y-1,h_1[y]=x$ 即 $h_0\big[h_1[i]\big]=i-1$ (称之为不好的拐角)不存在。称这样的操作序列是好的。 我们成功地将操作序列集合化简了,即我们找到了映射 :操作序列 $\rightarrow$ 好的操作序列 接下来就是分析映射 :好的操作序列 $\rightarrow$ 矩阵 我们猜想这恰是一个双射,证明如下 : 假设两组不同的好的操作序列为 $h,h'$ ,且结果相同。 首先找出第一个不等的列 $h_1[i]≠h_1'[i]$,不妨设 $h_1[i]<h_1'[i]$。 考虑位置 $(h_1'[i],i)$ ,显然,它只有可能为 $1$。 则可以推出 $h_0\big[h_1'[i]\big]\geq i,h_0'\big[h_1'[i]\big]<i$。 同时,由好操作序列,推出 $h_0'\big[h_1'[i]\big]≠i-1$ (否则这个行就可以进一步伸长),所以 $h_0'\big[h_1'[i]\big]\leq i-2$. 若 $i=1$ ,则有 $h_0'\big[h_1'[i]\big]\leq -1$ ,矛盾。 若 $i>1$,由于 $h_0'\big[h_1'[i]\big]\leq i-2$ ,所以有 $h_0'\big[h_1'[i]\big]<j-1<h_0\big[h_1'[i]\big]$ 考虑位置 $(h_1'[i],i-1)$ ,其被 $h_0$ 覆盖了,但是没有被 $h_0'$ 覆盖。 又因为 $h_1[i-1]=h_1'[i-1]$ (因为 $i$ 是第一个不等的),列的覆盖情况相同,所以该位置必然不等。 通过这个双射,我们把问题转化成了 : 求好操作序列的个数。 考虑容斥,钦定 $k$ 个不好的拐角,方案数为 : $f(k)=\dbinom{n}{k}\dbinom{m}{k}k!(n+1)^{m-k}(m+1)^{n-k}$ $\dbinom{n}{k}\dbinom{m}{k}$ 为选出行列的方案数,$k!$ 为行列匹配的方案数,构造不好的拐角的方案数恒为 $1$,后面的 $(n+1)^{m-k}(m+1)^{n-k}$ 是其余部分随意的方案数。 则 ${\rm Ans}=\sum\limits_{k=0}^{\min(n,m)}(-1)^kf(k)$。 ```cpp #include<algorithm> #include<cstdio> #define ll long long #define MaxN 500500 using namespace std; const int mod=998244353; ll powM(ll a,int t=mod-2){ ll ret=1; while(t){ if (t&1)ret=ret*a%mod; a=a*a%mod;t>>=1; }return ret; } ll fac[MaxN],ifac[MaxN]; ll C(int n,int m) {return fac[n]*ifac[m]%mod*ifac[n-m]%mod;} void Init(int n) { fac[0]=1; for (int i=1;i<=n;i++) fac[i]=fac[i-1]*i%mod; ifac[n]=powM(fac[n]); for (int i=n;i;i--) ifac[i-1]=ifac[i]*i%mod; } int n,m; int main() { scanf("%d%d",&n,&m); Init(max(n,m)); ll ans=0; for (int k=0;k<=min(n,m);k++){ ll buf=C(n,k)*C(m,k)%mod*fac[k]%mod*powM(n+1,m-k)%mod*powM(m+1,n-k)%mod; if (k&1)ans-=buf;else ans+=buf; }printf("%lld\n",(ans%mod+mod)%mod); return 0; } ```