题解 P3824 【[NOI2017]泳池】

Salamander

2018-05-23 16:43:34

Solution

首先,恰好等于$k$的概率可以通过差分转化为小于等于$k$的概率减去小于等于$k-1$的概率。 考虑怎么求小于等于$k$的概率。设$f_n$为底部宽为$n$时的答案,那么我们可以枚举最后连续的一段白色的长度,即$\displaystyle f_n=\sum_{i=0}^{k}f_{n-i-1}\cdot\sum_{j=2}^{\lfloor{k/i}\rfloor+1}dp_{i,j}$。其中$dp_{i,j}$表示一个宽为$i$的矩形,高度为$1$至$j-1$都全部是白色,高度为$j$的一行开始出现障碍,并且包含的所有贴住底部的白色矩形(包括宽不一定为$i$和高超过$j$的那些)大小都不超过$k$的概率。记$g_{i,j}=\sum_{l\geq j}dp_{i,l}$。 如果我们可以求出$dp$和$g$数组,那么转移方程就变成了齐次常系数线性递推的形式:$\displaystyle f_n=\sum_{i=0}^{k}f_{n-i-1}\cdot g_{i,2}$。注意到$k$比较大,传统的矩阵乘法为$O(k^3\log n)$无法通过,可以用多项式取模的方法来做。 先考虑怎么求出$dp$和$g$。枚举高度为$j$处的第一个障碍, $\displaystyle dp_{i,j}=[i\cdot(j-1)\leq k](1-p)p^{j-1}\sum_{l=1}^{i}\left(\sum_{q>j}dp_{l-1,q}\right)\left(\sum_{q\geq j}dp_{i-l,q}\right)$ 前缀和优化一下,就可以做到$O(k^2\log k)$了,因为合法的状态中要求$i\cdot(j-1)\leq k$,所以一共有$\sum_{i=1}^{k}\lfloor{k/i}\rfloor=O(k\log k)$个非零状态,每次转移为$O(k)$,已经可以通过此题。实际上还可以用FFT优化到$k\log^2k$。 $~$ 下面是用多项式取模优化齐次常系数线性递推的方法: 对于递推$f_n=\sum_{i=1}^ka_kf_{n-k}$ 设矩阵乘法中的转移矩阵为$A$,设初始状态矩阵为$B[f_k,f_{k-1},...,f_1]$,我们现在要求$BA^n$,即求$A^n$。 考虑$A$的特征多项式$\lambda(x)=det(xI-A)$,$I$为单位矩阵。由于转移矩阵的特殊性,在此处我们可以手动求一下特征多项式,$\lambda(x)=x^k-\sum_{i=1}^ka_ix^{k-i}$,是一个$k$次多项式,并且满足$\lambda(A)=0$。 如果我们把$A^n$表示成若干倍的$\lambda(A)$,即$A^n=\lambda(A)g(A)+r(A)$。由于$\lambda(A)=0$,所以$A^n=r(A)$。而$r(A)=A^n\mod\lambda(A)$,我们只需要求出$x^n$模特征多项式后的多项式,然后将$A$带入即可求出$A^n$。 $~$ 多项式取模可以使用Picks的方法,用FFT做到$k\log k$,加上多项式快速幂复杂度就是$O(k\log k\log n)$,边乘边取模。当然在这里没有必要,直接暴力取模即可。 $~$ 如何暴力取模?假设要求$f(x)\mod g(x)$,设$f(x)=g(x)h(x)+r(x)$,我们就是要求$r(x)$,并且要求$r(x)$的次数小于$g(x)$的次数。不妨将$g(x)$看成$0$,那么此时$f(x)$就等于$r(x)$。设$g(x)=\sum_{i=0}^kb_ix^i=0$,那么我们可以得出$\displaystyle x^k=-\sum_{i=0}^{k-1}\frac{b_i}{b_k}x^i$,我们只需要把$f(x)$中系数过大的一部分迭代,即可将系数一步步压到比$k$小。 两个$k$次多项式相乘,超出的部分系数是$O(k)$的,所以暴力取模的复杂度为$O(k^2)$,套上快速幂就是$O(k^2\log n)$,可以通过此题。 $~$ 然后我们考虑怎么统计答案。我们已经用一个关于$A$的$k-1$次多项式$A^n=R(A)=\sum_{i=0}^{k-1}r_iA^i$表示出了$A^n$。 根据矩阵乘法的分配律,$B\cdot R(A)=\sum_{i=0}^{k-1}r_i\cdot B\cdot A^i$。我们要的结果是矩阵的第一个元素,而$B\cdot A^i$的第一个元素就是$f_{k+i}$,所以我们要先求出$f_k,f_{k+1},...,f_{2k}$,然后直接扫一遍统计答案即可。 $~$ 复杂度$O(k^2(\log k+\log n))$,其实可以做到$O(k\log k(\log k+\log n))$。 我代码中的dp部分和上面写的转移不太一样。 ``` #include<bits/stdc++.h> #define For(i,_beg,_end) for(int i=(_beg),i##end=(_end);i<=i##end;++i) #define Rep(i,_beg,_end) for(int i=(_beg),i##end=(_end);i>=i##end;--i) template<typename T>T Max(const T &x,const T &y){return x<y?y:x;} template<typename T>T Min(const T &x,const T &y){return x<y?x:y;} template<typename T>int chkmax(T &x,const T &y){return x<y?(x=y,1):0;} template<typename T>int chkmin(T &x,const T &y){return x>y?(x=y,1):0;} template<typename T>void read(T &x){ T f=1;char ch=getchar(); for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')f=-1; for(x=0;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0'; x*=f; } typedef long long LL; const int N=1010,mod=998244353; int n,m; LL a,b,p,q,dp[N][N],g[N][N],pw[N]; LL A[N],f[N<<1]; LL power(LL,LL); LL Solve(int); int main(){ read(n);read(m);read(a);read(b); p=a*power(b,mod-2)%mod;q=(mod+1-p)%mod; pw[0]=1; For(i,1,m) pw[i]=pw[i-1]*p%mod; printf("%lld\n",(Solve(m)-Solve(m-1)+mod)%mod); return 0; } LL Solve(int k){ memset(dp,0,sizeof dp); memset(g,0,sizeof g); For(i,1,k+2) g[0][i]=dp[0][i]=1; For(i,1,k) Rep(j,k/i+1,2){ For(l,1,i) dp[i][j]=(dp[i][j]+g[l-1][j+1]*g[i-l][j]%mod*pw[l-1]%mod*q)%mod; g[i][j]=(g[i][j+1]*pw[i]+dp[i][j])%mod; } memset(A,0,sizeof A); For(i,0,k) A[i+1]=q*g[i][2]%mod*pw[i]%mod; memset(f,0,sizeof f); f[0]=1; For(i,1,k){ f[i]=g[i][2]*pw[i]%mod; For(j,1,i) f[i]=(f[i]+A[j]*f[i-j])%mod; } k++; For(i,k,k<<1) For(j,1,k) f[i]=(f[i]+A[j]*f[i-j])%mod; if(n<=k)return f[n]; int y=n-k,len=1,L=0; LL res[N<<2],tmp[N<<2],x[N<<2]; memset(res,0,sizeof res); memset(x,0,sizeof x); res[0]=1;x[1]=1; for(;y;y>>=1){ if(y&1){ For(i,0,L+len) tmp[i]=0; For(i,0,L) For(j,0,len) tmp[i+j]=(tmp[i+j]+res[i]*x[j])%mod; L+=len; Rep(i,L,k) For(j,1,k) tmp[i-j]=(tmp[i-j]+tmp[i]*A[j])%mod; chkmin(L,k-1); For(i,0,L) res[i]=tmp[i]; } For(i,0,len+len) tmp[i]=0; For(i,0,len) For(j,0,len) tmp[i+j]=(tmp[i+j]+x[i]*x[j])%mod; len<<=1; Rep(i,len,k) For(j,1,k) tmp[i-j]=(tmp[i-j]+tmp[i]*A[j])%mod; chkmin(len,k-1); For(i,0,len) x[i]=tmp[i]; } LL ans=0; For(i,0,k-1) ans=(ans+res[i]*f[i+k])%mod; return ans; } LL power(LL x,LL y){ LL res=1; for(;y;y>>=1,x=x*x%mod) if(y&1) res=res*x%mod; return res; } ```