柒葉灬 的博客

柒葉灬 的博客

FFT & NTT总结(个人笔记)

posted on 2019-01-03 16:06:28 | under 专题总结 |

FFT & NTT总结


$FFT$ 模板(短但有精度问题的):

int limit,l,r[maxm];
struct Complex{
    double x,y;
    Complex(double xx=0,double yy=0){
        x=xx;y=yy;
    }
    Complex operator +(const Complex &b)const{
        return Complex(x+b.x,y+b.y);
    }
    Complex operator -(const Complex &b)const{
        return Complex(x-b.x,y-b.y);
    }
    Complex operator *(const Complex &b)const{
        return Complex(x*b.x-y*b.y,x*b.y+y*b.x);
    }
}a[maxm],b[maxm];
void FFT(Complex *A,int type){
    for(int i=0;i<limit;i++)
        if(i<r[i])swap(A[i],A[r[i]]);
    for(int i=1;i<limit;i<<=1){
        Complex T(cos(Pi/i),type*sin(Pi/i));
        for(int j=0;j<limit;j+=i<<1){
            Complex t(1,0);
            for(int k=0;k<i;k++,t=t*T){
                Complex x=A[j+k],y=t*A[j+k+i];
                A[j+k]=x+y;
                A[j+k+i]=x-y;
            }
        }
    }
}
void calc(long long *ans,long long *A,long long *B,int len){
    limit=1,l=0;
    while(limit<=len<<1)limit<<=1,l++;
    for(int i=1;i<limit;i++)
        r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    for(int i=0;i<=len;i++){
        a[i].x=A[i];
        b[i].x=B[i];
    }
    FFT(a,1);
    FFT(b,1);
    for(int i=0;i<=limit;i++)
        a[i]=a[i]*b[i];
    FFT(a,-1);
    for(int i=0;i<=len<<1;i++)
        ans[i]=(long long)(a[i].x/limit+0.5);
    for(int i=0;i<=limit;i++)
        a[i].x=a[i].y=b[i].x=b[i].y=0;
}

$NTT$ 模板:

void NTT(long long *A,int type){
    for(int i=0;i<limit;i++)
        if(i<r[i])swap(A[i],A[r[i]]);
    for(int i=1;i<limit;i<<=1){
        long long T=qpow(type==1?G:Gi,(P-1)/(i<<1));
        for(int j=0;j<limit;j+=i<<1){
            long long t=1;
            for(int k=0;k<i;k++,t=t*T%P){
                long long x=A[j+k],y=t*A[j+k+i]%P;
                A[j+k]=(x+y)%P;
                A[j+k+i]=(x-y+P)%P;
            }
        }
    }
}

其中: $P$ 表示模数, $G$ 表示这个模数的原根, $Gi$ 表示原根的逆元。


当求多项式相乘的时候有限制条件时怎么办?

例子:把下列 $O(n^2)$ 代码用 $FFT$ 改成 $O(nlog_2n)$

void calc(){
    for(int i=0;i<len;i++)
        for(int j=i+1;j<=len;j++)
            C[i+j]+=A[i]*B[j];
    int ans=0;
    for(int i=1;i<=len<<1;i++)
        ans+=C[i];
    cout<<ans<<endl;
}

不难发现,上列式子 $C[i+j]+=A[i] \times B[j]$ 中,要求 $i < j $

但简单的 $FFT$ 并不能实现这个限制功能。

这时候我们需要强行把一个变成负数。

即: $C[i-j]+=A[i] \times B[-j]$

因为 $i-j<0$ ,所以在负数的地方加上 $len$

即: $C[i-j+len]+=A[i] \times B[len-j]$

$i-j<0$ ,所以 $i-j+len<len$ ,

下标 $[0,len-1]$ 的地方就是我们要的答案。

得到最终答案:翻转B数组,C[0 -> len-1]就是合法的答案。

拓展 :

若 $i \leq j$ 则是 $C[0 -> len]$

若 $i>j$ 则是 $C[len+1 -> len*2]$

若 $i \geq j$ 则是 $C[len -> len*2]$


上面的例子太简单了,换一个难的。

void calc(){
    for(int i=0;i<len;i++)
        for(int j=i+1;j<=len;j++)
            C[i+j]+=A[i]*B[j];
    for(int i=0;i<=len<<1;i++)
        cout<<C[i]<<" ";
}

上面的,不会......

update in 2019/7/30 :

可以用类似分治的操作左半区间的 $A_i$ 乘上右半区间的 $B_i$ ,

复杂度 $O(nlog^2n)$


任意模数:

void calc(int n,int m){
    limit=1,l=0;
    while(limit<=n+m)limit<<=1,l++;
    for(int i=1;i<limit;i++)
        r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    for(int i=0;i<=n;i++){
        A[i].x=num1[i]&16383;
        C[i].x=num1[i]>>14;
    }
    for(int i=0;i<=m;i++){
        B[i].x=num2[i]&16383;
        D[i].x=num2[i]>>14;
    }
    FFT(A,1);FFT(B,1);FFT(C,1);FFT(D,1);
    for(int i=0;i<=limit;i++){
        Complex a1=A[i],a2=C[i],b1=B[i],b2=D[i];
        A[i]=a1*b1;
        B[i]=a1*b2;
        C[i]=a2*b1;
        D[i]=a2*b2;
    }
    FFT(A,-1);FFT(B,-1);FFT(C,-1);FFT(D,-1);
    for(int i=0;i<=n+m;i++){
        long long x1=A[i].x+0.5,x2=B[i].x+0.5,x3=C[i].x+0.5,x4=D[i].x+0.5;
        x1%=P;x2%=P;x3%=P;x4%=P;//!!!
        res[i]=(x1%P+(x2<<14)%P+(x3<<14)%P+(x4<<28)%P)%P;
    }
    for(int i=0;i<=limit;i++){
        A[i].x=A[i].y=B[i].x=B[i].y=C[i].x=C[i].y=D[i].x=D[i].y=0;
    }
}

上面的代码调用了 $8$ 次FFT,太慢了。

下面是重点要背的

其中不仅是多模数的处理

还有是精度处理良好的FFT

优化:

void FFT(Complex *A){
    for(int i=1;i<limit;i++)
        if(i<r[i])swap(A[i],A[r[i]]);
    t[0].x=1;
    for(int i=1;i<limit;i<<=1){
        Complex T=Complex{cos(Pi/i),sin(Pi/i)};

        for(int j=i-2;j>=0;j-=2){
            t[j]=t[j>>1];
            t[j+1]=T*t[j];
        }

        for(int j=0;j<limit;j+=i<<1){
            for(int k=0;k<i;k++){
                Complex x=A[j+k],y=t[k]*A[j+k+i];
                A[j+k]=x+y;
                A[j+k+i]=x-y;
            }
        }
    }
}
void calc(int len){
    limit=1,l=0;
    while(limit<=len<<1)limit<<=1,l++;
    for(int i=1;i<limit;i++)
        r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    for(int i=0;i<=len;i++){
        A[i].x=num1[i]&32767;
        A[i].y=num1[i]>>15;
        B[i].x=num2[i]&32767;
        B[i].y=num2[i]>>15;
    }
    for(int i=len+1;i<limit;i++)
        A[i].clear(),B[i].clear();
    FFT(A);FFT(B);
    for(int i=0;i<limit;i++){
        int j=(limit-1)&(limit-i);
        C[j]=(Complex){0.5*(A[i].x+A[j].x),0.5*(A[i].y-A[j].y)}*B[i];
        D[j]=(Complex){0.5*(A[i].y+A[j].y),0.5*(A[j].x-A[i].x)}*B[i];
    }
    FFT(C);FFT(D);
    for(int i=len;i<=len<<1;i++){
        long long x1=C[i].x/limit+0.5;
        long long x2=C[i].y/limit+0.5;
        long long x3=D[i].x/limit+0.5;
        long long x4=D[i].y/limit+0.5;
        x1%=P;x2%=P;x3%=P;x4%=P;
        res[i]=(x1+(x2<<15)+(x3<<15)+(x4<<30))%P;
    }
}