[MtOI2019] T6 Solution

NaCly_Fish

2019-08-24 15:49:00

Solution

upd:之前式子有一点锅,现已修复。 这其实是一个很水的套路题。。 首先容易看出来 $f(x,0)$ 是个线性递推的形式,要求的是其 $k$ 阶前缀积。 要求乘积不太好搞,可以对 $2$ 取一下对数,化乘为加。 于是问题转化为: 一个数列 $a$: $$a_n=n\space(n\le42)$$ $$a_n=\sum\limits_{i=1}^{42}ia_{n-i}\space(n\ge 43)$$ 求它 $k$ 阶前缀和的第 $n$ 项。 **** 关于线性递推式的高阶前缀和有一个优美的性质。 设数列 $a$ 的递推系数为 $f$,那么在 $f$ 前面加个 $-1$ ,然后做 $k$ 阶差分得到的序列即 $a$ 的 $k$ 阶前缀和的递推式。( 当然要在后面扩展 $k$ 项,同时最后去掉 $-1$ ) 在此简短证明一下,设: $$a_n=\sum\limits_{i=1}^kf_ia_{n-i}$$ $$b_n=\sum\limits_{i=1}^na_i$$ $$b_n=b_{n-1}+a_n=b_{n-1}+\sum\limits_{i=1}^kf_ia_{n-i}$$ $$= b_{n-1}+\sum\limits_{i=1}^kf_i(b_{n-i}-b_{n-i-1})$$ 后面的那个求和展开,可以得到很多形如 $(f_i-f_{i-1})b_{n-i}$ 的式子,这就是一个很明显的差分形式,接下来的证明就很容易了。 得到 $k$ 阶前缀和的递推式后,把 $a$ 的 $k$ 阶前缀和的前几项也求出来,然后直接上线性递推板子即可。 不过要注意的是我们刚才取了个 $\log$,所以以上运算都要对 $\color{red} 998244352$ 取模,这需要用到任意模数。$7$ 次 FFT 的做法常数过大,不能通过;需要使用 $4$ 次 FFT 的做法。 还有就是模数不是素数时,算组合数很麻烦,所以直接用倍增多项式快速幂计算高阶差分或前缀和即可。 时间复杂度 $\Theta(k\log^2 k+k\log k\log n)$。 std: ```cpp #include<cstdio> #include<iostream> #include<cstring> #include<algorithm> #include<cmath> #define N 65539 #define ll long long #define reg register #define p 998244352 #define pi 3.141592653589793 using namespace std; struct complex{ double x,y; inline complex(double x=0,double y=0):x(x),y(y){} inline complex operator + (const complex& b) const{ return complex(x+b.x,y+b.y); } inline complex operator - (const complex& b) const{ return complex(x-b.x,y-b.y); } inline complex operator * (const complex& b) const{ return complex(x*b.x-y*b.y,x*b.y+y*b.x); } inline complex operator / (const int& b) const{ return complex(x/b,y/b); } inline complex operator ~ () const{ return complex(x,-y); } }rt[N]; struct matrix{ int a[43][43]; int siz; inline matrix(int _siz=0):siz(_siz){ memset(a,0,sizeof(a)); } inline matrix operator * (const matrix& b) const{ matrix res = matrix(siz); for(reg int i=0;i!=siz;++i) for(reg int j=0;j!=siz;++j) for(reg int k=0;k!=siz;++k) res.a[i][j] = (res.a[i][j]+(ll)a[i][k]*b.a[k][j])%p; return res; } }; inline matrix mat_pow(matrix a,ll t){ matrix res = matrix(a.siz); for(reg int i=0;i!=a.siz;++i) res.a[i][i] = 1; while(1){ if(t&1) res = res*a; t >>= 1; if(t==0) break; a = a*a; } return res; } int rev[N]; int siz; inline int add(int a,int b){ return a+b>=p?a+b-p:a+b; } inline int dec(int a,int b){ return a<b?a-b+p:a-b; } inline int getlen(int n){ return 1<<(32-__builtin_clz(n)); } inline void init(int n){ int lim = 1; while(lim<=n) lim <<= 1,++siz; for(reg int i=1;i!=lim;++i) rev[i] = (rev[i>>1]>>1)|((i&1)<<(siz-1)); rt[lim>>1] = complex(1,0); for(reg int i=1;i!=(lim>>1);++i) rt[i+(lim>>1)] = complex(cos(pi*2/lim*i),sin(pi*2/lim*i)); for(reg int i=(lim>>1)-1;i;--i) rt[i] = rt[i<<1]; } inline void dft(complex *f,int lim){ static complex a[N]; int shift = siz-__builtin_ctz(lim); for(reg int i=0;i!=lim;++i) a[rev[i]>>shift] = f[i]; for(reg int mid=1;mid!=lim;mid<<=1) for(reg int j=0;j!=lim;j+=(mid<<1)) for(reg int k=0;k!=mid;++k){ complex x = a[j|k|mid]*rt[mid|k]; a[j|k|mid] = a[j|k]-x; a[j|k] = a[j|k]+x; } for(reg int i=0;i!=lim;++i) f[i] = a[i]; } inline void idft(complex *f,int lim){ reverse(f+1,f+lim); dft(f,lim); for(reg int i=0;i!=lim;++i) f[i] = f[i]/lim; } inline void multiply(const int *A,const int *B,int n,int m,int *R,int len,bool flag){ static complex f[N],g[N],h[N],q[N]; complex t,f0,f1,g0,g1; ll x,y,z; int lim = getlen(n+m); for(reg int i=0;i!=lim;++i){ f[i] = complex(A[i]>>15,A[i]&32767); g[i] = complex(B[i]>>15,B[i]&32767); } dft(f,lim); if(flag) for(reg int i=0;i!=lim;++i) g[i] = f[i]; else dft(g,lim); for(reg int i=0;i!=lim;++i){ t = ~f[i?lim-i:0]; f0 = (f[i]-t)*complex(0,-0.5),f1 = (f[i]+t)*0.5; t = ~g[i?lim-i:0]; g0 = (g[i]-t)*complex(0,-0.5),g1 = (g[i]+t)*0.5; h[i] = f1*g1; q[i] = f1*g0 + f0*g1 + f0*g0*complex(0,1); } idft(h,lim),idft(q,lim); for(reg int i=0;i<=len;++i){ x = (ll)(h[i].x+0.5)%p<<30; y = (ll)(q[i].x+0.5)<<15; z = q[i].y+0.5; R[i] = (x+y+z)%p; } memset(R+len+1,0,(lim-len)<<2); } inline void inverse(const int *f,int n,int *R){ static int g[N],h[N],st[30]; memset(g,0,getlen(n<<1)<<2); int top = 0,lim =1 ; while(n){ st[++top] = n; n >>= 1; } g[0] = 1; while(top--){ n = st[top+1]; while(lim<=(n<<1)) lim <<= 1; memcpy(h,f,(n+1)<<2); memset(h+n+1,0,(lim-n)<<2); multiply(h,g,n,n>>1,h,n,false); multiply(h,g,n,n>>1,h,n,false); for(reg int i=(n>>1);i<=n;++i) g[i] = dec(add(g[i],g[i]),h[i]); } memcpy(R,g,(n+1)<<2); } inline void divide(const int *f,const int *ig,int n,int m,int *R){ static int A[N],B[N]; memcpy(A,f,(n+1)<<2),memcpy(B,ig,(m+1)<<2); reverse(A,A+n+1); int tt = n-m,lim = getlen((n-m)<<1); memset(A+tt+1,0,(lim-tt)<<2); for(reg int i=min(m,tt)+1;i!=lim;++i) B[i] = 0; multiply(A,B,tt,tt,R,n-m,false); reverse(R,R+tt+1); } inline void mod(const int *f,const int *g,const int *ig,int n,int m,int *R){ if(n<m) return; static int A[N],B[N]; memcpy(B,f,(n+1)<<2); int lim = getlen(n); divide(f,ig,n,m,R); memcpy(A,g,(m+1)<<2); memset(A+m+1,0,(lim-m+2)<<2); memset(R+n-m+1,0,(lim-n+m+1)<<2); multiply(A,R,m,n-m,R,m-1,false); for(reg int i=0;i!=m;++i) R[i] = dec(B[i],R[i]); } void mod_power(const int *G,int k,ll t,int *R){ int f[N],g[N],ig[N]; memset(f,0,sizeof(f)); memset(g,0,sizeof(g)); memset(ig,0,sizeof(ig)); memcpy(ig,G,(k+1)<<2); reverse(ig,ig+k+1); inverse(ig,k,ig); int n = 1,m = 0; f[1] = g[0] = 1; while(1){ if(t&1){ multiply(f,g,n,m,g,n+m,false); mod(g,G,ig,n+m,k,g); m = min(n+m,k-1); } t >>= 1; if(t==0) break; multiply(f,f,n,n,f,n<<1,true); mod(f,G,ig,n<<1,k,f); n = min(n<<1,k-1); } memcpy(R,g,k<<2); } void poly_pow(const int *f,int n,int k,int *R){ static int g[N],h[N]; memset(g,0,sizeof(g)),memset(h,0,sizeof(h)); memcpy(g,f,(n+1)<<2); h[0] = 1; int m = 0; while(1){ if(k&1){ multiply(h,g,n,m,h,n+m,false); m += n; } k >>= 1; if(k==0) break; multiply(g,g,n,n,g,n<<1,true); n <<= 1; } memcpy(R,h,(m+1)<<2); } inline int poww(int a,int t,int m){ int res = 1; while(t){ if(t&1) res = (ll)res*a%m; a = (ll)a*a%m; t >>= 1; } return res; } int k,lim; int a[N],f[N],c[N],F[N],G[N]; ll n; int special(){ if(k==1){ for(int i=1;i<=42;++i) a[42] += a[42-i]*i; for(int i=1;i<=42;++i) a[i] = add(a[i],a[i-1]); for(int i=43;i;--i) f[i] = dec(f[i],f[i-1]); } int siz = 42+k; matrix A = matrix(siz); for(reg int i=0;i!=siz;++i) A.a[i][0] = f[i+1]; for(reg int i=1;i!=siz;++i) A.a[i-1][i] = 1; A = mat_pow(A,n-1); int res = 0; for(reg int i=0;i!=siz;++i) res = (res+(ll)a[siz-1-i]*A.a[i][siz-1])%p; return poww(2,res,p+1); } int main(){ int ans = 0; scanf("%lld%d",&n,&k); if(n==1){ putchar('2'); return 0; } f[0] = p-1; for(reg int i=1;i<=42;++i) a[i-1] = f[i] = i; if(k<=1){ printf("%d",special()); return 0; } init((k+42)<<1); c[0] = 1,c[1] = p-1; poly_pow(c,1,k,c); lim = k+42; for(reg int i=0;i<=42;++i) for(reg int j=0;j<=k;++j) G[i+j] = (G[i+j]+(ll)f[i]*c[j])%p; inverse(c,lim-1,c); for(reg int i=42;i!=lim;++i) for(reg int j=1;j<=42;++j) a[i] = (a[i]+(ll)a[i-j]*j)%p; multiply(a,c,lim-1,lim-1,a,lim-1,false); reverse(G,G+lim+1); for(reg int i=0;i<=lim;++i) G[i] = G[i]?p-G[i]:0; mod_power(G,lim,n-1,F); for(reg int i=0;i!=lim;++i) ans = (ans+(ll)F[i]*a[i])%p; printf("%d",poww(2,ans,p+1)); return 0; } ``` ps:其实可以做到 $\Theta(k\log k + \log n)$,需要带个几百倍的常数,这里就不写了(逃