loj6077(洛谷2513进化版)

yukuai26

2018-01-11 09:55:14

Solution

一道经典老题,山东2017集训出了n,k<=10w的题目,然后被机房毒瘤出题人抄来,虽然机房人都一眼秒了nk,但是机房2个集训队大佬就1个ac的,其他人。。nk暴力分拿了就跑。。。。 经过漫长的理解(看题解),我似乎终于懂了 这边主要是来讲一下一种全新的思路,可以做到k\*sqrt(k) 首先,从小往大加数,第i次加数可以使逆序对增加 0--n-1(套路), 那么我们有了一个不定方程 g1+g2+g3....+gn=k 0<=g[i]<i 接下来我们容斥考虑合法的解 设f[i][j]=表示i个数和为j的个数,显然i&1==1 时为-,0时为+ f[i][j]显然可以dp,我们考虑一个数列,可以所有数+1,当没有1时,插入一个1, 这样就可以dp了,可以发现因为 i\*(i+1)/2<=k,所以dp预处理的时间是k(sqrt(k)) 对于每个f[i][j] 显然他就答案的贡献是 设m为和为k-j,n个元的不定方程解的个数 if (i&1) 贡献=f[i][j]\*m; else 贡献=-f[i][j]\*m; 通过挡板法可以用组合数求解,这样就完美解决了 但是。。因为组合数的特性(代码中用了阶乘求逆元),山东集训的mod是1e9+7,这边是1w(不是素数,丧心病狂)。所以这边这个题解的代码无法过掉,(那边过了),所以就不要复制了哦,这仅仅是一个神奇的思路,当然,你可以复制我的再帮我手写ex卢卡斯和crt合并过掉此题。。 另外,机房集训队大佬说了一种n lon n 的多项式乘的做法,更加优秀(然而我并没有听懂) 希望会的人也来发题解,谢谢qaq ```cpp #include <cmath> #include <cstdio> #include <cstring> #include <iostream> #include <algorithm> #include <queue> #include <set> #define ll long long #define N 100005 #define inf 1000000005 #define mod 1000000007 #include <map> //注意mod的修改 using namespace std; int num,fac[N*2],inv[N*2]; inline int read(){ char c=getchar();int tot=1;while ((c<'0'|| c>'9')&&c!='-') c=getchar();if (c=='-'){tot=-1;c=getchar();} int sum=0;while (c>='0'&&c<='9'){sum=sum*10+c-'0';c=getchar();}return sum*tot;} int qpow(int a,int b){ int ans=1,tmp=a; for(;b;b>>=1,tmp=1ll*tmp*tmp%mod) if(b&1)ans=1ll*ans*tmp%mod; return ans; } inline void wr(int x){if (x<0) {putchar('-');wr(-x);return;}if(x>=10)wr(x/10);putchar(x%10+'0');} inline void wrn(int x){wr(x);putchar('\n');} inline void wri(int x){wr(x);putchar(' ');} int C(int x,int y){ if(x>y)return 0; return (ll)fac[y]*inv[x]%mod*inv[y-x]%mod; } //int nedge,Next[N*2],head[N],to[N*2]; //void add(int a,int b){nedge++;Next[nedge]=head[a];head[a]=nedge;to[nedge]=b;} int n,m,sum,ans,f[500][N]; int main(){ //freopen("reverse.in","r",stdin); //freopen("reverse.out","w",stdout); n=read();m=read(); fac[0]=inv[0]=1; for(int i=1;i<=n+m;i++)fac[i]=(ll)fac[i-1]*i%mod; inv[n+m]=qpow(fac[n+m],mod-2); for(int i=n+m-1;i;i--)inv[i]=(ll)inv[i+1]*(i+1)%mod; num=sqrt(m*2)+5; f[0][0]=1; for (int i=1;i<=num;i++){ for (int j=i;j<=m;j++){ f[i][j]=(f[i][j-i]+f[i-1][j-i])%mod; if (j>=n+1) f[i][j]=(f[i][j]-f[i-1][j-n-1])%mod; } } for (int i=0;i<=m;i++){ sum=0; for (int j=0;j<=num;j++){ if (j&1) sum-=f[j][i]; else sum+=f[j][i]; sum=(sum%mod+mod)%mod; } ans=(ans+(ll)sum*C(n-1,n+m-i-1))%mod; } wrn(ans); return 0; } ```