P5050 【模板】多项式多点求值 题解

以下题解仅供学习参考使用。

抄袭、复制题解,以达到刷AC率/AC数量或其他目的的行为,在洛谷是严格禁止的。

洛谷非常重视学术诚信。此类行为将会导致您成为作弊者。具体细则请查看洛谷社区规则

评论

  • GNAQ
    什么毒瘤玩意
  • 子曰子悦
    哈哈哈哈哈哈这题解有毒
  • sermoon
    仿佛预感到了插值的惨烈
  • 142857cs
    插值都出到1e5了,这个应该也可以?
  • 142857cs
    插值应该没那么惨吧,O(n^2)的插值有多少人会?
  • ButterflyDew
    哈哈哈哈哈额
  • rEdWhitE_uMbrElla
    毒瘤。。。
  • 南城忆潇湘
    原谅我笑了出来
  • olinr
    卡常巨佬啊%%%%
  • olinr
    666666
作者: 玫葵之蝶 更新时间: 2018-12-17 23:21  在Ta的博客查看 举报    41  

脑子想想就知道这种两个log还常数大的题可以暴力艹啊

显然我们可以有一个O(nm)的秦九韶暴力,这个不会你就可以退役了

开始卡常:

1.首先加register,能加的都加,出了存系数的那个不加

2.写快读快写,这个也不必非写fread,我就没写

3.循环展开,实测4层最优

4.最后开个unsigned long long,你循环展开的过程量就不用取模了

5.写个unordered_map,如果已经算过了,就直接输出(这个貌似没卵用)

然后找个夜深人静的好时候,提交就好了,你就可以AC了

#pragma GCC optimize("Ofast")
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<unordered_map>
#define LL unsigned long long
#define R register
using namespace std;
inline void read(int &x){
    x=0;int f=1;char ch=getchar();
    while(ch<'0' || ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0' && ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
    if(f==-1)x=-x;
}
char s[20];int cnt;
inline void write(int x){
    cnt=0;
    while(x){
        s[++cnt]=x%10+'0';
        x/=10;
    }
    for(R int i=cnt;i;--i)putchar(s[i]);
    putchar('\n');
}
const int N=64005,mod=998244353;
unordered_map<int,int> mp;
int a[N];
int main(){
    R int n,m,i,j,x;
    R LL b[17],c1,c2,c3,c4,now;
    read(n);read(m);
    for(i=0;i<=n;++i)read(a[i]);
    for(i=1;i<=m;++i){
        read(x);
        if(mp[x]){write(mp[x]);continue;}
        b[0]=1;
        for(j=1;j<=16;++j)b[j]=b[j-1]*x%mod;
        now=a[n];
        for(j=n-1;j-15>=0;j-=16){
            c1=now*b[16]+a[j]*b[15]+a[j-1]*b[14]+a[j-2]*b[13];
            c2=a[j-3]*b[12]+a[j-4]*b[11]+a[j-5]*b[10]+a[j-6]*b[9];
            c3=a[j-7]*b[8]+a[j-8]*b[7]+a[j-9]*b[6]+a[j-10]*b[5];
            c4=a[j-11]*b[4]+a[j-12]*b[3]+a[j-13]*b[2]+a[j-14]*b[1];
            now=(c1+c2+c3+c4+a[j-15])%mod;
        }
        for(j=n%16-1;~j;--j)now=(now*x+a[j])%mod;
        write(mp[x]=now);
    }
    return 0;
}

评论

  • _皎月半洒花
    一开始构造的多项式是m次的吧
作者: mrsrz  更新时间: 2018-12-26 18:44  在Ta的博客查看 举报    14  

我们将要求值的点均分成两份,构造多项式$P_0(x)=\prod\limits_{i=1}^{\lfloor\frac n 2\rfloor}(x-x_i)$,$P_1(x)=\prod\limits_{i=\lfloor\frac n 2\rfloor+1}^{n}(x-x_i)$。

显然,对于$\forall i\in[1,\lfloor\frac n 2 \rfloor]$,有$P_0(x_i)=0$。$P_1$同理。

我们假设多项式$D(x),R(x)$满足:$D(x)$是一个$n-\lfloor\frac n 2\rfloor$次多项式,$R(x)$是一个次数不超过$\lfloor\frac n 2\rfloor-1$的多项式,且$A(x)=P_0(x)D(x)+R(x)$。

那么对于$\forall i\in[1,\lfloor\frac n 2 \rfloor]$,有$A(x_i)=R(x_i)$。$P_1$同理可得。

由于$R(x)$的次数是$A(x)$的次数的一半,所以我们把一个规模为$n$的问题,用$O(n\log n)$的复杂度(多项式取模,详见多项式除法),转化为两个规模为$\frac n 2$的问题。

分治计算即可,由主定理得时间复杂度$O(n\log^2 n)$。

求每次的$P_0(x),P_1(x)$,可以开始用一次分治NTT预处理。此处时间复杂度也为$O(n\log^2 n)$。

常数极大就是了QAQ。

Code:

#include<cstdio>
#include<algorithm>
typedef long long LL;
const int md=998244353,N=262145;
//Poly begin
int rev[N],lim,g[20][N],M,vv;
inline void upd(int&a){a+=a>>31&md;}
inline int pow(int a,int b){
    int ret=1;
    for(;b;b>>=1,a=(LL)a*a%md)if(b&1)ret=(LL)ret*a%md;return ret;
}
inline void init(int n){
    int l=-1;
    for(lim=1;lim<n;lim<<=1)++l;M=l+1;
    for(int i=1;i<lim;++i)rev[i]=(rev[i>>1]>>1)|((i&1)<<l);vv=pow(lim,md-2);
}
void NTT(int*a,int f){
    for(int i=1;i<lim;++i)if(i<rev[i])std::swap(a[i],a[rev[i]]);
    for(int i=0;i<M;++i){
        const int*G=g[i],c=1<<i;
        for(int j=0;j<lim;j+=c<<1)
        for(int k=0;k<c;++k){
            const int x=a[j+k],y=a[j+k+c]*(LL)G[k]%md;
            upd(a[j+k]+=y-md),upd(a[j+k+c]=x-y);
        }
    }
    if(!f){
        for(int i=0;i<lim;++i)a[i]=(LL)a[i]*vv%md;
        std::reverse(a+1,a+lim);
    }
}
void INV(const int*a,int*B,int n){
    if(n==1){
        *B=pow(*a,md-2);
        return;
    }
    INV(a,B,n+1>>1);
    init(n<<1);
    static int A[N];
    for(int i=0;i<n;++i)A[i]=a[i];
    for(int i=n;i<lim;++i)A[i]=0;
    NTT(A,1),NTT(B,1);
    for(int i=0;i<lim;++i)B[i]=B[i]*(2-(LL)A[i]*B[i]%md+md)%md;
    NTT(B,0);
    for(int i=n;i<lim;++i)B[i]=0;
}
void REV(int*A,int n){std::reverse(A,A+n+1);}
void MOD(const int*a,const int*b,int*R,int n,int m){
    static int A[N],B[N],D[N];
    for(int i=0;i<n<<1;++i)D[i]=0;
    for(int i=0;i<=m;++i)B[i]=b[i];
    REV(B,m);
    for(int i=n-m+1;i<n<<1;++i)B[i]=0;
    INV(B,D,n-m+1);
    init(n-m+1<<1);
    for(int i=0;i<=n-m+1;++i)A[i]=a[n-i];
    for(int i=n-m+2;i<lim;++i)A[i]=0;
    NTT(A,1),NTT(D,1);
    for(int i=0;i<lim;++i)D[i]=(LL)D[i]*A[i]%md;
    NTT(D,0);
    REV(D,n-m);
    init(n+1<<1);
    for(int i=n-m+1;i<lim;++i)D[i]=0;
    for(int i=0;i<lim;++i)A[i]=B[i]=0;
    for(int i=0;i<=m;++i)B[i]=b[i];
    for(int i=0;i<=n;++i)A[i]=a[i];
    NTT(B,1),NTT(D,1);
    for(int i=0;i<lim;++i)B[i]=(LL)B[i]*D[i]%md;
    NTT(B,0);
    for(int i=0;i<m;++i)upd(R[i]=A[i]-B[i]);
}
//Poly end
int n,m,A[N],a[N],*P[N],len[N];
void solve(int l,int r,int o){
    if(l==r){
        len[o]=1;
        P[o]=new int[2];
        upd(P[o][0]=-a[l]),P[o][1]=1;
        return;
    }
    const int mid=l+r>>1,L=o<<1,R=L|1;
    solve(l,mid,L),solve(mid+1,r,R);
    len[o]=len[L]+len[R];
    P[o]=new int[len[o]+1];
    init(len[o]+1<<1);
    static int A[N],B[N];
    for(int i=len[L];~i;--i)A[i]=P[L][i];
    for(int i=len[L]+1;i<lim;++i)A[i]=0;
    for(int i=len[R];~i;--i)B[i]=P[R][i];
    for(int i=len[R]+1;i<lim;++i)B[i]=0;
    NTT(A,1),NTT(B,1);
    for(int i=0;i<lim;++i)A[i]=(LL)A[i]*B[i]%md;
    NTT(A,0);
    for(int i=len[o];~i;--i)P[o][i]=A[i];
}
void solve(int l,int r,int o,const int*A){
    if(l==r){printf("%d\n",*A);return;}
    const int mid=l+r>>1,L=o<<1,R=L|1;
    int B[len[o]+2<<1];
    MOD(A,P[L],B,len[o]-1,len[L]);
    solve(l,mid,L,B);
    MOD(A,P[R],B,len[o]-1,len[R]);
    solve(mid+1,r,R,B);
}
int main(){
    for(int i=0;i<19;++i){
        int*G=g[i];
        G[0]=1;
        const int gi=G[1]=pow(3,(md-1)/(1<<i+1));
        for(int j=2;j<1<<i;++j)G[j]=(LL)G[j-1]*gi%md;
    }
    scanf("%d%d",&n,&m);if(!m)return 0;
    for(int i=0;i<=n;++i)scanf("%d",A+i);
    for(int i=1;i<=m;++i)scanf("%d",a+i);
    solve(1,m,1);
    if(n>=m)MOD(A,P[1],A,n,m);
    solve(1,m,1,A);
    return 0;
}

评论

  • Judge_Cheung
    还人傻常熟大,每次用时都是我的 $\frac{1}{n}$ ,$n >> 1$
作者: bztMinamoto 更新时间: 2019-02-13 19:52  在Ta的博客查看 举报    5  

传送门

人傻常数大.jpg

因为求逆的时候没清零结果调了几个小时……

前置芝士

多项式除法,多项式求逆

什么?你不会?左转你谷模板区,包教包会

题解

首先我们要知道一个结论$$f(x_0)\equiv f(x)\pmod{(x-x_0)}$$

其中$x_0$为一个常量,$f(x_0)$也为一个常量

证明如下,设$f(x)=g(x)(x-x_0)+A$,也就是说$A$是$f(x)$对$(x-x_0)$这个多项式取模之后的结果

因为$(x-x_0)$的最高次项为$1$,所以$A$的最高次项为$0$,也就是说$A$是一个常数,即$f(x)\equiv A\pmod{(x-x_0)}$

我们把$x_0$代入上式,得$f(x_0)=g(x_0)(x_0-x_0)+A$,同理可得$f(x_0)\equiv A\pmod{(x-x_0)}$

于是我们知道上式成立

这有毛用啊$O(n\log n)$多项式取模还没我暴力快

乍一看的确没啥卵用,但是考虑取模的过程是否能优化呢?

答案是可以的,我们考虑分治。设当前分治区间为$[l,r]$,令$P_0=\prod_{i=l}^{mid}(x-x_i)$,$P_1=\prod_{i=mid+1}^r (x-x_i)$,当前已经算出了$A\equiv f(x)\pmod{\prod_{i=l}^r(x-x_i)}$,那么只要分别用$A$对$P_0$和$P_1$取模,然后继续递归下去就行了。取模之后$A(x)$的最高次项的次数变为原来的一半,问题规模也就变为原来的一半。继续递归下去就行了

时间复杂度为$O(n\log^2n)$

upd:改了改代码,常数应该会小一点,比方说分治到某个时候暴力秦九韶展开

//minamoto
#include<bits/stdc++.h>
#define R register
#define fp(i,a,b) for(R int i=(a),I=(b)+1;i<I;++i)
#define fd(i,a,b) for(R int i=(a),I=(b)-1;i>I;--i)
#define go(u) for(int i=head[u],v=e[i].v;i;i=e[i].nx,v=e[i].v)
using namespace std;
char buf[1<<21],*p1=buf,*p2=buf;
inline char getc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++;}
int read(){
    R int res,f=1;R char ch;
    while((ch=getc())>'9'||ch<'0')(ch=='-')&&(f=-1);
    for(res=ch-'0';(ch=getc())>='0'&&ch<='9';res=res*10+ch-'0');
    return res*f;
}
char sr[1<<21],z[20];int C=-1,Z=0;
inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;}
void print(R int x){
    if(C>1<<20)Ot();if(x<0)sr[++C]='-',x=-x;
    while(z[++Z]=x%10+48,x/=10);
    while(sr[++C]=z[Z],--Z);sr[++C]='\n';
}
const int N=(1<<17)+5,P=998244353;
inline int add(R int x,R int y){return x+y>=P?x+y-P:x+y;}
inline int dec(R int x,R int y){return x-y<0?x-y+P:x-y;}
inline int mul(R int x,R int y){return 1ll*x*y-1ll*x*y/P*P;}
int ksm(R int x,R int y){
    R int res=1;
    for(;y;y>>=1,x=mul(x,x))(y&1)?res=mul(res,x):0;
    return res;
}
int r[19][N],w[2][N],lg[N],inv[19];
void Pre(){
    fp(d,1,17){
        fp(i,1,(1<<d)-1)r[d][i]=(r[d][i>>1]>>1)|((i&1)<<(d-1));
        lg[1<<d]=d,inv[d]=ksm(1<<d,P-2);
    }
    for(R int t=(P-1)>>1,i=1,x,y;i<131072;i<<=1,t>>=1){
        x=ksm(3,t),y=ksm(332748118,t),w[0][i]=w[1][i]=1;
        fp(k,1,i-1)
            w[1][k+i]=mul(w[1][k+i-1],x),
            w[0][k+i]=mul(w[0][k+i-1],y);
    }
}
int lim,d,n,m;
inline void init(R int len){lim=1,d=0;while(lim<len)lim<<=1,++d;}
void NTT(int *A,int ty){
    fp(i,0,lim-1)if(i<r[d][i])swap(A[i],A[r[d][i]]);
    for(R int mid=1;mid<lim;mid<<=1)
        for(R int j=0,t;j<lim;j+=(mid<<1))
            fp(k,0,mid-1)
                A[j+k+mid]=dec(A[j+k],t=mul(w[ty][mid+k],A[j+k+mid])),
                A[j+k]=add(A[j+k],t);
    if(!ty)fp(i,0,lim-1)A[i]=mul(A[i],inv[d]);
}
void Inv(int *a,int *b,int len){
    if(len==1)return b[0]=ksm(a[0],P-2),void();
    Inv(a,b,len>>1),lim=(len<<1),d=lg[lim];
    static int A[N],B[N];
    fp(i,0,len-1)A[i]=a[i],B[i]=b[i];fp(i,len,lim-1)A[i]=B[i]=0;
    NTT(A,1),NTT(B,1);
    fp(i,0,lim-1)A[i]=mul(A[i],mul(B[i],B[i]));
    NTT(A,0);
    fp(i,0,len-1)b[i]=dec(add(b[i],b[i]),A[i]);
    fp(i,len,lim-1)b[i]=0;
}
struct node{
    node *lc,*rc;vector<int>vec;int deg;
    void Mod(const int *a,int *r,int n){
        static int A[N],B[N],D[N];
        int len=1;while(len<=n-deg)len<<=1;
        fp(i,0,n)A[i]=a[n-i];fp(i,0,deg)B[i]=vec[deg-i];
        fp(i,n-deg+1,len-1)B[i]=0;
        Inv(B,D,len);
        lim=(len<<1),d=lg[lim];
        fp(i,n-deg+1,lim-1)A[i]=D[i]=0;
        NTT(A,1),NTT(D,1);
        fp(i,0,lim-1)A[i]=mul(A[i],D[i]);
        NTT(A,0);
        reverse(A,A+n-deg+1);
        init(n+1);
        fp(i,n-deg+1,lim-1)A[i]=0;
        fp(i,0,deg)B[i]=vec[i];fp(i,deg+1,lim-1)B[i]=0;
        NTT(A,1),NTT(B,1);
        fp(i,0,lim-1)A[i]=mul(A[i],B[i]);
        NTT(A,0);
        fp(i,0,deg-1)r[i]=dec(a[i],A[i]);
    }
    void Mul(){
        static int A[N],B[N];deg=lc->deg+rc->deg,vec.resize(deg+1),init(deg+1);
        fp(i,0,lc->deg)A[i]=lc->vec[i];fp(i,lc->deg+1,lim-1)A[i]=0;
        fp(i,0,rc->deg)B[i]=rc->vec[i];fp(i,rc->deg+1,lim-1)B[i]=0;
        NTT(A,1),NTT(B,1);
        fp(i,0,lim-1)A[i]=mul(A[i],B[i]);
        NTT(A,0);
        fp(i,0,deg)vec[i]=A[i];
    }
}pool[N],*rt;
int A[N],a[N],tot;
inline node* newnode(){return &pool[tot++];}
void solve(node* &p,int l,int r){
    p=newnode();
    if(l==r)return p->deg=1,p->vec.resize(2),p->vec[0]=P-a[l],p->vec[1]=1,void();
    int mid=(l+r)>>1;
    solve(p->lc,l,mid),solve(p->rc,mid+1,r);
    p->Mul();
}
int b[25];
void calc(node* p,int l,int r,const int *A){
    if(r-l<=512){
        fp(i,l,r){
            int x=a[i],c1,c2,c3,c4,now=A[r-l];
            b[0]=1;fp(j,1,16)b[j]=mul(b[j-1],x);
            for(R int j=r-l-1;j-15>=0;j-=16){
                c1=(1ll*now*b[16]+1ll*A[j]*b[15]+1ll*A[j-1]*b[14]+1ll*A[j-2]*b[13])%P,
                c2=(1ll*A[j-3]*b[12]+1ll*A[j-4]*b[11]+1ll*A[j-5]*b[10]+1ll*A[j-6]*b[9])%P,
                c3=(1ll*A[j-7]*b[8]+1ll*A[j-8]*b[7]+1ll*A[j-9]*b[6]+1ll*A[j-10]*b[5])%P,
                c4=(1ll*A[j-11]*b[4]+1ll*A[j-12]*b[3]+1ll*A[j-13]*b[2]+1ll*A[j-14]*b[1])%P,
                now=(0ll+c1+c2+c3+c4+A[j-15])%P;
            }
            fd(j,(r-l)%16-1,0)now=(1ll*now*x+A[j])%P;
            print(now);
        }
        return;
    }
    int mid=(l+r)>>1,b[p->deg+1];
    p->lc->Mod(A,b,p->deg-1),calc(p->lc,l,mid,b);
    p->rc->Mod(A,b,p->deg-1),calc(p->rc,mid+1,r,b);
}
int main(){
//  freopen("testdata.in","r",stdin);
    n=read(),m=read();if(!m)return 0;
    Pre();
    fp(i,0,n)A[i]=read();
    fp(i,1,m)a[i]=read();
    solve(rt,1,m);
    if(n>=m)rt->Mod(A,A,n);
    calc(rt,1,m,A);
    return Ot(),0;
}

评论

  • 还没有评论
作者: Memory_of_winter 更新时间: 2018-12-31 13:25  在Ta的博客查看 举报    2  

我的博客

题目大意:给你一个$n$次多项式$f(x)$,以及$m$个$x_i$,对于$i\in[1,m]$,求$f(x_i)$

题解:多项式多点求值

令$g(x)=\prod\limits_{i=1}^m(x-x_i)$,求出$R(x)$使得$f(x)=Q(x)\times g(x)+R(x)$。因为当$x=x_i$时,$g(x)=0$,即$f(x)=R(x)$,$f(x)$是$n$次的,$R(x)$是$m-1$次的,似乎可以使得问题缩小了

考虑分治,现在区间为$[l,r]$,令$g_L(x)=\prod\limits_{i=l}^{mid}(x-x_i)$,$g_R(x)=\prod\limits_{i=mid}^r(x-x_i)$,所以$R_L(x)=f(x)\bmod g_L(x)$,$R_R(x)=f(x)\bmod g_R(x)$。最后当$l=r$时,第$i$个的值就是当前$R(x)$的常数项。

那$g(x)$怎么算呢,分治$FFT$,可以先把每个的$g(x)$求出来,用$vector$保存一下就行了

卡点:不知道为什么,用$C++$会$MLE$,$C++11$就过了,有可能是$vector$初始化部分出锅了

C++ Code:

#include <cstdio>
#include <algorithm>
#include <vector>
const int mod = 998244353, G = 3;

namespace Math {
    inline int pw(int base, int p) {
        static int res;
        for (res = 1; p; p >>= 1, base = static_cast<long long> (base) * base % mod) if (p & 1) res = static_cast<long long> (res) * base % mod;
        return res;
    }
    inline int inv(int x) { return pw(x, mod - 2); }
}
inline void reduce(int &x) { x += x >> 31 & mod; }

#define maxn 65536
int a[maxn], ans[maxn];
namespace Poly {
#define N maxn
    int rev[N], lim, s, ilim;
    int Wn[N + 1];
    inline void clear(register int *l, const int *r) {
        if (l >= r) return ;
        while (l != r) *l++ = 0;
    }

    inline void init(const int n) {
        s = -1, lim = 1; while (lim < n) lim <<= 1, ++s; ilim = Math::inv(lim);
        for (register int i = 0; i < lim; ++i) rev[i] = rev[i >> 1] >> 1 | (i & 1) << s;
        const int t = Math::pw(G, (mod - 1) / lim);
        *Wn = 1; for (register int *i = Wn; i != Wn + lim; ++i) *(i + 1) = static_cast<long long> (*i) * t % mod;
    }
    inline void NTT(int *A, const int op = 1) {
        static int Wt[N];
        for (register int i = 1; i < lim; ++i) if (i < rev[i]) std::swap(A[i], A[rev[i]]);
        for (register int mid = 1; mid < lim; mid <<= 1) {
            const int t = lim / mid >> 1;
            *Wt = Wn[op ? 0 : lim];
            for (register int *i = Wt, W = 0; i != Wt + mid; ++i, W += t) *i = Wn[op ? W : lim - W];
            for (register int i = 0; i < lim; i += mid << 1) {
                for (register int j = 0; j < mid; ++j) {
                    const int X = A[i + j], Y = static_cast<long long> (Wt[j]) * A[i + j + mid] % mod;
                    reduce(A[i + j] += Y - mod), reduce(A[i + j + mid] = X - Y);
                }
            }
        }
        if (!op) for (register int *i = A; i != A + lim; ++i) *i = static_cast<long long> (*i) * ilim % mod;
    }

    std::vector<int> P[N << 1], S[N << 1];
    int C[N], D[N];
    void DC_NTT(int rt, int l, int r) {
        if (l == r) { P[rt] = {a[l], 1}; return ; }
        int mid = l + r >> 1;
        DC_NTT(rt << 1, l, mid), DC_NTT(rt << 1 | 1, mid + 1, r);
        int L = rt << 1, R = rt << 1 | 1;
        int n = P[L].size(), m = P[R].size();
        init(n + m - 1);
        std::copy(P[L].begin(), P[L].end(), C); clear(C + n, C + lim);
        std::copy(P[R].begin(), P[R].end(), D); clear(D + m, D + lim);
        NTT(C), NTT(D);
        for (int i = 0; i < lim; ++i) C[i] = static_cast<long long> (C[i]) * D[i] % mod;
        NTT(C, 0);
        P[rt].assign(C, C + n + m - 1);
    }

    int E[N];
    void INV(int *A, int *B, int n) {
        if (n == 1) {
            *B = Math::inv(*A);
            return ;
        }
        INV(A, B, n + 1 >> 1);
        init(n + n - 1);
        std::copy(A, A + n, E); clear(E + n, E + lim);
        clear(B + (n + 1 >> 1), B + lim);
        NTT(B), NTT(E);
        for (int i = 0; i < lim; ++i) B[i] = (2 + mod - static_cast<long long> (B[i]) * E[i] % mod) * B[i] % mod;
        NTT(B, 0); clear(B + n, B + lim);
    }
    int F[N];
    void DIV(int A, int n, int B, int m) {
        const int len = n - m + 1;
        init(len << 1);
        std::reverse_copy(S[A].begin(), S[A].end(), C); clear(C + len, C + lim);
        std::reverse_copy(P[B].begin(), P[B].end(), D); clear(D + len, D + lim);
        clear(F, F + lim);
        INV(D, F, len);
        NTT(C), NTT(F);
        for (int i = 0; i < lim; ++i) F[i] = static_cast<long long> (F[i]) * C[i] % mod;
        NTT(F, 0);
        clear(F + len, F + lim);
    }
    void __DIVMOD(int res, int A, int n, int B, int m) {
        if (n < m) {
            S[res].assign(S[A].begin(), S[A].end());
            return ;
        }
        DIV(A, n, B, m);
        init(n);
        std::reverse_copy(F, F + n - m + 1, C); clear(C + n - m + 1, C + lim);
        std::copy(P[B].begin(), P[B].end(), D); clear(D + m, D + lim);
        NTT(C), NTT(D);
        for (int i = 0; i < lim; ++i) C[i] = static_cast<long long> (C[i]) * D[i] % mod;
        NTT(C, 0);
        for (int i = 0; i < m - 1; ++i) reduce(C[i] = S[A][i] - C[i]);
        S[res].assign(C, C + m - 1);
    }
    void DIVMOD(int res, int A) {
        int n = S[A].size(), m = P[res].size();
        __DIVMOD(res, A, n, res, m);
    }

    void solve(int rt, int l, int r) {
        if (l == r) {
            ans[l] = S[rt][0];
            return ;
        }
        int mid = l + r >> 1;
        DIVMOD(rt << 1, rt), DIVMOD(rt << 1 | 1, rt);
        solve(rt << 1, l, mid), solve(rt << 1 | 1, mid + 1, r);
    }

    void work(int *f, int n, int m) {
        DC_NTT(1, 1, m);
        S[0].assign(f, f + n);
        DIVMOD(1, 0);
        solve(1, 1, m);
    }
#undef N
}

int n, m;
int f[maxn];
int main() {
    scanf("%d%d", &n, &m); if (!m) return 0; ++n;
    for (int i = 0; i < n; ++i) scanf("%d", f + i);
    for (int i = 1; i <= m; ++i) scanf("%d", a + i), reduce(a[i] = -a[i]);
    Poly::work(f, n, m);
    for (int i = 1; i <= m; ++i) printf("%d\n", ans[i]);
    return 0;
}

评论

  • hdhd
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
  • qurui
    %%%%%%%%%%%%%
作者: Newuser 更新时间: 2018-12-12 22:15  在Ta的博客查看 举报    2  

问题:给定一个n次多项式f(x),现在请你对于i∈[1,m],求出f(ai​)。

我们考虑分治处理,对于i∈[1,m/2] , 我们构造一个g(x) = ∏(x-ai) (1<=i<=n/2),我们发现对于[1,m/2]中所有的x值代入这个函数值都为0,那么我们直接对于[1,m/2]的原函数的f都%上g(x)应该是不会影响(即消除掉若干个g(x)不会影响)。这样就可以同时达到消除大于等于x^(m/2)次项的目的,我们发现消除到最后剩下的那个常数项的值恰好就是我们要求的f(x)(因为除了常数项都被我们消除掉了), 时间复杂度为O(nlog^2n)

我们写成一个类似线段树的结构把插值的多项式存下来。

欢迎来Newuser小站玩owo

#include<stdio.h>
#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cmath>
#include<vector>
using namespace std;
const int maxn = 64005;
const int mod = 998244353;
const int g = 3;
int mul(int x,int y) { return 1ll*x*y%mod; }
int add(int x,int y) { x+=y; return x>=mod?x-mod:x; }
int sub(int x,int y) { x-=y; return x<0?x+mod:x; }
int ksm(int a,int b) {
    int ans = 1;
    for(;b;b>>=1,a=mul(a,a))
        if(b&1) ans = mul(ans,a);
    return ans;
}
void ntt(int *a,int deg,int dft) {
    for(int i=0,j=0;i<deg;i++) {
        if(i<j) swap(a[i],a[j]);
        for(int k=(deg>>1);(j^=k)<k;k>>=1);
    }
    for(int st=1;st<deg;st<<=1) {
        int dwg = (dft==1) ? ksm(g,(mod-1)/(st<<1)) : ksm(g,mod-1-(mod-1)/(st<<1));
        for(int i=0;i<deg;i+=(st<<1)) {
            int ng = 1;
            for(int j=i;j<i+st;j++) {
                int x = a[j]; int y = mul(ng,a[j+st]);
                ng = mul(ng,dwg);
                a[j] = add(x,y); a[j+st] = add(x,mod-y);
            }
        }
    }
    if(dft==1) return;
    int invs = ksm(deg,mod-2);
    for(int i=0;i<deg;i++) a[i] = mul(invs,a[i]);
}
int F[18][maxn*4];
int n,m;
int ta[maxn*4],tb[maxn*4];
void MULL(int *a,int *b,int *c,int le) {
    for(int i=0;i<le;i++) ta[i]=b[i];
    for(int j=0;j<le;j++) tb[j]=c[j];
    ntt(ta,le,1); ntt(tb,le,1);
    for(int i=0;i<le;i++) ta[i]=mul(ta[i],tb[i]);
    ntt(ta,le,-1);
    for(int i=0;i<le;i++) a[i] = ta[i];
}
void ginv(int deg,int *a,int *b) {
    if(deg==1) { b[0]=ksm(a[0],mod-2); return; }
    ginv(deg>>1,a,b);
    for(int i=0;i<deg;i++) ta[i]=a[i];
    for(int i=deg;i<(deg<<1);i++) ta[i]=0;
    ntt(ta,deg<<1,1); ntt(b,deg<<1,1);
    for(int i=0;i<(deg<<1);i++) b[i]=mul(b[i],sub(2,mul(ta[i],b[i])));
    ntt(b,(deg<<1),-1);
    for(int i=deg;i<(deg<<1);i++) b[i]=0;
}
int Q[maxn*4],Fr[maxn*4],Gr[maxn*4],Gri[maxn*4],tmp[maxn*4];
void gmod(int *O,int *F,int *G,int n,int m) {//F%G==O (F n G m)
    for(int i=0;i<=n;i++) Fr[i] = F[n-i];
    for(int i=0;i<=m;i++) Gr[i] = G[m-i];
    for(int i=n-m+2;i<=m;i++) Gr[i] = 0;
    int le = 1; for(;le<=(n-m);le<<=1);
    for(int i=n-m+1;i<=m;i++) Gr[i]=0;
    ginv(le,Gr,Gri);
    le = 1; for(;le<=(n<<1);le<<=1);
    MULL(Q,Gri,Fr,le);
    reverse(Q,Q+n-m+1);
    for(int i=n-m+1;i<le;i++) Q[i]=0;
    MULL(tmp,Q,G,le);
    for(int i=0;i<m;i++) O[i] = sub(F[i],tmp[i]);
    for(int i=0;i<le;i++) tmp[i] = Fr[i] = Gr[i] = Gri[i] = Q[i] = 0;
}
vector<int>ve[maxn*2];
int tot,rt,ls[maxn*2],rs[maxn*2],X[maxn];
int tpa[maxn*4],tpb[maxn*4],tpc[maxn*4];
void maketree(int &p,int l,int r) {
    p = ++tot;
    int len = 1; for(;len<r-l+2;len<<=1);
    ve[p].resize(len); // len ci
    if(l==r) {
        ve[p][0] = (mod-(X[l]%mod+mod)%mod); ve[p][1] = 1; return; 
    }
    int mid = (l+r)>>1;
    maketree(ls[p],l,mid); maketree(rs[p],mid+1,r);

    int ss = ve[ls[p]].size();
    for(int i=0;i<ss;i++) tpa[i] = ve[ls[p]][i]; for(int i=ss;i<len;i++) tpa[i]=0;

        ss = ve[rs[p]].size();
    for(int i=0;i<ss;i++) tpb[i] = ve[rs[p]][i]; for(int i=ss;i<len;i++) tpb[i]=0;

    MULL(tpc,tpa,tpb,len);
    for(int i=0;i<len;i++) ve[p][i] = tpc[i];
}
int G[maxn*4];
int ANS[maxn];
void DC(int dep,int p,int l,int r,int cs) {
    int m = r-l+1;
    for(int i=0;i<=m;i++) G[i] = ve[p][i];
    if(cs>=m) gmod(F[dep],F[dep-1],G,cs,m);
    else {
        for(int i=0;i<=m-1;i++) F[dep][i] = F[dep-1][i];
    }
    for(int i=0;i<=m;i++) G[i] = 0;
    if(l==r) {
        ANS[l] = F[dep][0];
        return;
    }
    int mid = (l+r)>>1;
    DC(dep+1,ls[p],l,mid,m-1);
    DC(dep+1,rs[p],mid+1,r,m-1);
}
int main() {
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;i++) {
        int x; scanf("%d",&x);
        F[0][i] = x;
    }
    for(int i=1;i<=m;i++) scanf("%d",&X[i]);
    maketree(rt,1,m);
    for(int i=0;i<=m;i++) G[i] = ve[rt][i];
    DC(1,rt,1,m,n);
    for(int i=1;i<=m;i++) printf("%d\n",ANS[i]);
}
 
反馈
如果你认为某个题解有问题,欢迎向洛谷反馈,以帮助更多的同学。



请具体说明理由,以增加反馈的可信度。