loj6077(洛谷2513进化版)
yukuai26
2018-01-11 09:55:14
一道经典老题,山东2017集训出了n,k<=10w的题目,然后被机房毒瘤出题人抄来,虽然机房人都一眼秒了nk,但是机房2个集训队大佬就1个ac的,其他人。。nk暴力分拿了就跑。。。。
经过漫长的理解(看题解),我似乎终于懂了
这边主要是来讲一下一种全新的思路,可以做到k\*sqrt(k)
首先,从小往大加数,第i次加数可以使逆序对增加 0--n-1(套路),
那么我们有了一个不定方程
g1+g2+g3....+gn=k 0<=g[i]<i
接下来我们容斥考虑合法的解
设f[i][j]=表示i个数和为j的个数,显然i&1==1 时为-,0时为+
f[i][j]显然可以dp,我们考虑一个数列,可以所有数+1,当没有1时,插入一个1,
这样就可以dp了,可以发现因为 i\*(i+1)/2<=k,所以dp预处理的时间是k(sqrt(k))
对于每个f[i][j] 显然他就答案的贡献是
设m为和为k-j,n个元的不定方程解的个数
if (i&1) 贡献=f[i][j]\*m;
else 贡献=-f[i][j]\*m;
通过挡板法可以用组合数求解,这样就完美解决了
但是。。因为组合数的特性(代码中用了阶乘求逆元),山东集训的mod是1e9+7,这边是1w(不是素数,丧心病狂)。所以这边这个题解的代码无法过掉,(那边过了),所以就不要复制了哦,这仅仅是一个神奇的思路,当然,你可以复制我的再帮我手写ex卢卡斯和crt合并过掉此题。。
另外,机房集训队大佬说了一种n lon n 的多项式乘的做法,更加优秀(然而我并没有听懂)
希望会的人也来发题解,谢谢qaq
```cpp
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <queue>
#include <set>
#define ll long long
#define N 100005
#define inf 1000000005
#define mod 1000000007
#include <map>
//注意mod的修改
using namespace std;
int num,fac[N*2],inv[N*2];
inline int read(){
char c=getchar();int tot=1;while ((c<'0'|| c>'9')&&c!='-') c=getchar();if (c=='-'){tot=-1;c=getchar();}
int sum=0;while (c>='0'&&c<='9'){sum=sum*10+c-'0';c=getchar();}return sum*tot;}
int qpow(int a,int b){
int ans=1,tmp=a;
for(;b;b>>=1,tmp=1ll*tmp*tmp%mod)
if(b&1)ans=1ll*ans*tmp%mod;
return ans;
}
inline void wr(int x){if (x<0) {putchar('-');wr(-x);return;}if(x>=10)wr(x/10);putchar(x%10+'0');}
inline void wrn(int x){wr(x);putchar('\n');}
inline void wri(int x){wr(x);putchar(' ');}
int C(int x,int y){
if(x>y)return 0;
return (ll)fac[y]*inv[x]%mod*inv[y-x]%mod;
}
//int nedge,Next[N*2],head[N],to[N*2];
//void add(int a,int b){nedge++;Next[nedge]=head[a];head[a]=nedge;to[nedge]=b;}
int n,m,sum,ans,f[500][N];
int main(){
//freopen("reverse.in","r",stdin);
//freopen("reverse.out","w",stdout);
n=read();m=read();
fac[0]=inv[0]=1;
for(int i=1;i<=n+m;i++)fac[i]=(ll)fac[i-1]*i%mod;
inv[n+m]=qpow(fac[n+m],mod-2);
for(int i=n+m-1;i;i--)inv[i]=(ll)inv[i+1]*(i+1)%mod;
num=sqrt(m*2)+5;
f[0][0]=1;
for (int i=1;i<=num;i++){
for (int j=i;j<=m;j++){
f[i][j]=(f[i][j-i]+f[i-1][j-i])%mod;
if (j>=n+1) f[i][j]=(f[i][j]-f[i-1][j-n-1])%mod;
}
}
for (int i=0;i<=m;i++){
sum=0;
for (int j=0;j<=num;j++){
if (j&1) sum-=f[j][i];
else sum+=f[j][i];
sum=(sum%mod+mod)%mod;
}
ans=(ans+(ll)sum*C(n-1,n+m-i-1))%mod;
}
wrn(ans);
return 0;
}
```