P3824 [NOI2017]泳池
斯德哥尔摩
2018-07-30 12:18:08
[P3824 [NOI2017]泳池](https://www.luogu.org/problemnew/show/P3824)
这里有一篇大佬的[博客](https://www.luogu.org/blog/ShadowassIIXVIIIIV/solution-p3824)。
我就是看这个看懂的。(我好菜啊。。。)
首先设$ans_i$是安全面积$S<=i$的概率,那么问题可以转成:
求$ans_k-ans_{k-1}$。
注意到如果一个竖行$x$满足$A_{x,1}$是危险的话,这一行就没用。
然后连续的$A_{x,1}$是安全的必须不超过$K$个,一列最高连续不超过$K$。
我们可以以此为划分。
设$f[i][j]$为一个宽度为$i$的区域,其中最低的一根柱子恰好是$j$的方案数。
设$g[i][j]$表示一个宽度为$i$的区域,其中最低的一根柱子$>=j$的方案数。
那么枚举最靠左的一根恰好是$j$的柱子的位置,可以得到:
$$f[i][j]=\sum_{k=1}^iq^j*(1-q)*g[k-1][j+1]* g[i-k][j]$$
注意到这个$dp$的复杂度目测是$K^3$。
但$i* j<=K$,所以时间复杂度应为$O(K^2log_2K)$。
再来计算原问题。
在最后一列强行加上一列第一位是危险的,最后再除$p$去掉这样一行的贡献。
我们把连续的非危险列以及其后面一个危险列看作一个整体,计算出概率。
设$H_i$为$i$列的答案,$ai$为上面计算出的$i$列的答案。
那么$H_n=\sum_{i=1}^{K+1}H_{n-i}* a_i$
显然是一个常系数线性递推,可以矩乘做到$O(K^3log_2N)$,也可以利用之前的多项式取模做到$O(K^2)$。
附代码:
```cpp
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#define MAXN 2010
#define MOD 998244353LL
using namespace std;
long long n,k,x,y,p,q;
long long val[MAXN],st[MAXN],f[MAXN],line_one[MAXN],line_two[MAXN],a[MAXN],ans[MAXN];
long long dp[MAXN][MAXN];
inline long long read(){
long long date=0,w=1;char c=0;
while(c<'0'||c>'9'){if(c=='-')w=-1;c=getchar();}
while(c>='0'&&c<='9'){date=date*10+c-'0';c=getchar();}
return date*w;
}
long long mexp(long long a,long long b,long long c){
long long s=1;
while(b){
if(b&1)s=s*a%c;
a=a*a%c;
b>>=1;
}
return s;
}
long long solve(long long x){
for(int i=0;i<=x+1;i++)dp[i][0]=1;
for(int i=x;i>=1;i--)
for(int j=1;j*i<=x;j++){
long long s=0;
for(int k=1;k<=j;k++)s=(s+dp[i+1][k-1]*dp[i][j-k]%MOD)%MOD;
s=s*p%MOD*mexp(q,i,MOD)%MOD;
dp[i][j]=(dp[i+1][j]+s)%MOD;
}
x++;
val[1]=p;
for(int i=1;i<x;i++)val[i+1]=dp[1][i]*p%MOD;
st[0]=1;
for(int i=1;i<x;i++)
for(int j=0;j<i;j++)
st[i]=(st[i]+st[j]*val[i-j]%MOD)%MOD;
for(int i=1;i<=x;i++)f[x-i]=MOD-val[i];
f[x]=ans[0]=a[1]=1;
long long times=n+1;
while(times){
if(times&1){
for(int i=0;i<=x;i++){line_one[i]=ans[i];ans[i]=0;}
for(int i=0;i<=x;i++)
for(int j=0;j<=x;j++)
ans[i+j]=(ans[i+j]+line_one[i]*a[j]%MOD)%MOD;
for(int i=2*x;i>=x;i--)
for(int j=0;j<=x;j++)
ans[i-x+j]=(ans[i-x+j]-ans[i]*f[j]%MOD+MOD)%MOD;
}
for(int i=0;i<=x;i++){line_one[i]=line_two[i]=a[i];a[i]=0;}
for(int i=0;i<=x;i++)
for(int j=0;j<=x;j++)
a[i+j]=(a[i+j]+line_one[i]*line_two[j]%MOD)%MOD;
for(int i=2*x;i>=x;i--)
for(int j=0;j<=x;j++)
a[i-x+j]=(a[i-x+j]-a[i]*f[j]%MOD+MOD)%MOD;
times>>=1;
}
long long s=0;
for(int i=0;i<x;i++)s=(s+st[i]*ans[i]%MOD)%MOD;
memset(dp,0,sizeof(dp));
memset(a,0,sizeof(a));
memset(ans,0,sizeof(ans));
memset(st,0,sizeof(st));
return s*mexp(p,MOD-2,MOD)%MOD;
}
void work(){
long long ans=(solve(k)-solve(k-1)+MOD)%MOD;
printf("%lld\n",ans);
}
void init(){
n=read();k=read();x=read();y=read();
q=x*mexp(y,MOD-2,MOD)%MOD;
p=(MOD+1-q)%MOD;
}
int main(){
init();
work();
return 0;
}
```