题解 P5050 【【模板】多项式多点求值】

Newuser

2018-12-12 22:15:11

Solution

问题:给定一个n次多项式f(x),现在请你对于i∈[1,m],求出f(ai​)。 我们考虑分治处理,对于i∈[1,m/2] , 我们构造一个g(x) = ∏(x-ai) (1<=i<=n/2),我们发现对于[1,m/2]中所有的x值代入这个函数值都为0,那么我们直接对于[1,m/2]的原函数的f都%上g(x)应该是不会影响(即消除掉若干个g(x)不会影响)。这样就可以同时达到消除大于等于x^(m/2)次项的目的,我们发现消除到最后剩下的那个常数项的值恰好就是我们要求的f(x)(因为除了常数项都被我们消除掉了), 时间复杂度为O(nlog^2n) 我们写成一个类似线段树的结构把插值的多项式存下来。 [欢迎来Newuser小站玩owo](http://www.newuser.top/2018/12/04/%E3%80%90moban%E3%80%91%E5%A4%9A%E9%A1%B9%E5%BC%8F%E7%AE%97%E6%B3%95%E6%A8%A1%E6%9D%BF/) ```cpp #include<stdio.h> #include<cstdio> #include<algorithm> #include<iostream> #include<cmath> #include<vector> using namespace std; const int maxn = 64005; const int mod = 998244353; const int g = 3; int mul(int x,int y) { return 1ll*x*y%mod; } int add(int x,int y) { x+=y; return x>=mod?x-mod:x; } int sub(int x,int y) { x-=y; return x<0?x+mod:x; } int ksm(int a,int b) { int ans = 1; for(;b;b>>=1,a=mul(a,a)) if(b&1) ans = mul(ans,a); return ans; } void ntt(int *a,int deg,int dft) { for(int i=0,j=0;i<deg;i++) { if(i<j) swap(a[i],a[j]); for(int k=(deg>>1);(j^=k)<k;k>>=1); } for(int st=1;st<deg;st<<=1) { int dwg = (dft==1) ? ksm(g,(mod-1)/(st<<1)) : ksm(g,mod-1-(mod-1)/(st<<1)); for(int i=0;i<deg;i+=(st<<1)) { int ng = 1; for(int j=i;j<i+st;j++) { int x = a[j]; int y = mul(ng,a[j+st]); ng = mul(ng,dwg); a[j] = add(x,y); a[j+st] = add(x,mod-y); } } } if(dft==1) return; int invs = ksm(deg,mod-2); for(int i=0;i<deg;i++) a[i] = mul(invs,a[i]); } int F[18][maxn*4]; int n,m; int ta[maxn*4],tb[maxn*4]; void MULL(int *a,int *b,int *c,int le) { for(int i=0;i<le;i++) ta[i]=b[i]; for(int j=0;j<le;j++) tb[j]=c[j]; ntt(ta,le,1); ntt(tb,le,1); for(int i=0;i<le;i++) ta[i]=mul(ta[i],tb[i]); ntt(ta,le,-1); for(int i=0;i<le;i++) a[i] = ta[i]; } void ginv(int deg,int *a,int *b) { if(deg==1) { b[0]=ksm(a[0],mod-2); return; } ginv(deg>>1,a,b); for(int i=0;i<deg;i++) ta[i]=a[i]; for(int i=deg;i<(deg<<1);i++) ta[i]=0; ntt(ta,deg<<1,1); ntt(b,deg<<1,1); for(int i=0;i<(deg<<1);i++) b[i]=mul(b[i],sub(2,mul(ta[i],b[i]))); ntt(b,(deg<<1),-1); for(int i=deg;i<(deg<<1);i++) b[i]=0; } int Q[maxn*4],Fr[maxn*4],Gr[maxn*4],Gri[maxn*4],tmp[maxn*4]; void gmod(int *O,int *F,int *G,int n,int m) {//F%G==O (F n G m) for(int i=0;i<=n;i++) Fr[i] = F[n-i]; for(int i=0;i<=m;i++) Gr[i] = G[m-i]; for(int i=n-m+2;i<=m;i++) Gr[i] = 0; int le = 1; for(;le<=(n-m);le<<=1); for(int i=n-m+1;i<=m;i++) Gr[i]=0; ginv(le,Gr,Gri); le = 1; for(;le<=(n<<1);le<<=1); MULL(Q,Gri,Fr,le); reverse(Q,Q+n-m+1); for(int i=n-m+1;i<le;i++) Q[i]=0; MULL(tmp,Q,G,le); for(int i=0;i<m;i++) O[i] = sub(F[i],tmp[i]); for(int i=0;i<le;i++) tmp[i] = Fr[i] = Gr[i] = Gri[i] = Q[i] = 0; } vector<int>ve[maxn*2]; int tot,rt,ls[maxn*2],rs[maxn*2],X[maxn]; int tpa[maxn*4],tpb[maxn*4],tpc[maxn*4]; void maketree(int &p,int l,int r) { p = ++tot; int len = 1; for(;len<r-l+2;len<<=1); ve[p].resize(len); // len ci if(l==r) { ve[p][0] = (mod-(X[l]%mod+mod)%mod); ve[p][1] = 1; return; } int mid = (l+r)>>1; maketree(ls[p],l,mid); maketree(rs[p],mid+1,r); int ss = ve[ls[p]].size(); for(int i=0;i<ss;i++) tpa[i] = ve[ls[p]][i]; for(int i=ss;i<len;i++) tpa[i]=0; ss = ve[rs[p]].size(); for(int i=0;i<ss;i++) tpb[i] = ve[rs[p]][i]; for(int i=ss;i<len;i++) tpb[i]=0; MULL(tpc,tpa,tpb,len); for(int i=0;i<len;i++) ve[p][i] = tpc[i]; } int G[maxn*4]; int ANS[maxn]; void DC(int dep,int p,int l,int r,int cs) { int m = r-l+1; for(int i=0;i<=m;i++) G[i] = ve[p][i]; if(cs>=m) gmod(F[dep],F[dep-1],G,cs,m); else { for(int i=0;i<=m-1;i++) F[dep][i] = F[dep-1][i]; } for(int i=0;i<=m;i++) G[i] = 0; if(l==r) { ANS[l] = F[dep][0]; return; } int mid = (l+r)>>1; DC(dep+1,ls[p],l,mid,m-1); DC(dep+1,rs[p],mid+1,r,m-1); } int main() { scanf("%d%d",&n,&m); for(int i=0;i<=n;i++) { int x; scanf("%d",&x); F[0][i] = x; } for(int i=1;i<=m;i++) scanf("%d",&X[i]); maketree(rt,1,m); for(int i=0;i<=m;i++) G[i] = ve[rt][i]; DC(1,rt,1,m,n); for(int i=1;i<=m;i++) printf("%d\n",ANS[i]); } ```