P3824 [NOI2017]泳池

斯德哥尔摩

2018-07-30 12:18:08

Personal

[P3824 [NOI2017]泳池](https://www.luogu.org/problemnew/show/P3824) 这里有一篇大佬的[博客](https://www.luogu.org/blog/ShadowassIIXVIIIIV/solution-p3824)。 我就是看这个看懂的。(我好菜啊。。。) 首先设$ans_i$是安全面积$S<=i$的概率,那么问题可以转成: 求$ans_k-ans_{k-1}$。 注意到如果一个竖行$x$满足$A_{x,1}$是危险的话,这一行就没用。 然后连续的$A_{x,1}$是安全的必须不超过$K$个,一列最高连续不超过$K$。 我们可以以此为划分。 设$f[i][j]$为一个宽度为$i$的区域,其中最低的一根柱子恰好是$j$的方案数。 设$g[i][j]$表示一个宽度为$i$的区域,其中最低的一根柱子$>=j$的方案数。 那么枚举最靠左的一根恰好是$j$的柱子的位置,可以得到: $$f[i][j]=\sum_{k=1}^iq^j*(1-q)*g[k-1][j+1]* g[i-k][j]$$ 注意到这个$dp$的复杂度目测是$K^3$。 但$i* j<=K$,所以时间复杂度应为$O(K^2log_2K)$。 再来计算原问题。 在最后一列强行加上一列第一位是危险的,最后再除$p$去掉这样一行的贡献。 我们把连续的非危险列以及其后面一个危险列看作一个整体,计算出概率。 设$H_i$为$i$列的答案,$ai$为上面计算出的$i$列的答案。 那么$H_n=\sum_{i=1}^{K+1}H_{n-i}* a_i$ 显然是一个常系数线性递推,可以矩乘做到$O(K^3log_2N)$,也可以利用之前的多项式取模做到$O(K^2)$。 附代码: ```cpp #include<iostream> #include<algorithm> #include<cstdio> #include<cstring> #define MAXN 2010 #define MOD 998244353LL using namespace std; long long n,k,x,y,p,q; long long val[MAXN],st[MAXN],f[MAXN],line_one[MAXN],line_two[MAXN],a[MAXN],ans[MAXN]; long long dp[MAXN][MAXN]; inline long long read(){ long long date=0,w=1;char c=0; while(c<'0'||c>'9'){if(c=='-')w=-1;c=getchar();} while(c>='0'&&c<='9'){date=date*10+c-'0';c=getchar();} return date*w; } long long mexp(long long a,long long b,long long c){ long long s=1; while(b){ if(b&1)s=s*a%c; a=a*a%c; b>>=1; } return s; } long long solve(long long x){ for(int i=0;i<=x+1;i++)dp[i][0]=1; for(int i=x;i>=1;i--) for(int j=1;j*i<=x;j++){ long long s=0; for(int k=1;k<=j;k++)s=(s+dp[i+1][k-1]*dp[i][j-k]%MOD)%MOD; s=s*p%MOD*mexp(q,i,MOD)%MOD; dp[i][j]=(dp[i+1][j]+s)%MOD; } x++; val[1]=p; for(int i=1;i<x;i++)val[i+1]=dp[1][i]*p%MOD; st[0]=1; for(int i=1;i<x;i++) for(int j=0;j<i;j++) st[i]=(st[i]+st[j]*val[i-j]%MOD)%MOD; for(int i=1;i<=x;i++)f[x-i]=MOD-val[i]; f[x]=ans[0]=a[1]=1; long long times=n+1; while(times){ if(times&1){ for(int i=0;i<=x;i++){line_one[i]=ans[i];ans[i]=0;} for(int i=0;i<=x;i++) for(int j=0;j<=x;j++) ans[i+j]=(ans[i+j]+line_one[i]*a[j]%MOD)%MOD; for(int i=2*x;i>=x;i--) for(int j=0;j<=x;j++) ans[i-x+j]=(ans[i-x+j]-ans[i]*f[j]%MOD+MOD)%MOD; } for(int i=0;i<=x;i++){line_one[i]=line_two[i]=a[i];a[i]=0;} for(int i=0;i<=x;i++) for(int j=0;j<=x;j++) a[i+j]=(a[i+j]+line_one[i]*line_two[j]%MOD)%MOD; for(int i=2*x;i>=x;i--) for(int j=0;j<=x;j++) a[i-x+j]=(a[i-x+j]-a[i]*f[j]%MOD+MOD)%MOD; times>>=1; } long long s=0; for(int i=0;i<x;i++)s=(s+st[i]*ans[i]%MOD)%MOD; memset(dp,0,sizeof(dp)); memset(a,0,sizeof(a)); memset(ans,0,sizeof(ans)); memset(st,0,sizeof(st)); return s*mexp(p,MOD-2,MOD)%MOD; } void work(){ long long ans=(solve(k)-solve(k-1)+MOD)%MOD; printf("%lld\n",ans); } void init(){ n=read();k=read();x=read();y=read(); q=x*mexp(y,MOD-2,MOD)%MOD; p=(MOD+1-q)%MOD; } int main(){ init(); work(); return 0; } ```