题解 P4002 【生成树计数】

Jμdge

2019-04-22 17:41:39

Solution

为什么这些大佬讲用 prufer 算贡献的时候都不带解释的说...是太显然了么? 结果小蒟蒻我两篇题解上翻下翻盯了一个中午才看出来是个什么鬼...我太菜了emmm 总而言之就是列出答案的式子然后各种推(多项式题的套路?) ### 但在此之前得先会 Prufer 序列丫! # Prufer 序列 长话短说、简而言之、说到底就是: 一棵无根树的 Prufer 序列构成方法就是每次取出**编号最小**的**叶子**,然后把与其相邻的唯一那个点的编号记录进**序列尾部**,直到只剩两个节点,这样生成的序列长度为 n-2 对于一棵无根树,它唯一对应了一个 Prufer 序列,一个 Prufer 序列也唯一对应了一棵树,也就是说树与 Prufer 序列之间呈双射(就是一一对应)的关系 我们可以用 Prufer 证出一个完全图的生成树个数为 $n^{n-2}$ 因为我们考虑每个生成树都敌营一个 Prufer 序列,而 n-2 长度的 Prufer 序列每个位置都可以放置 1~n 的数,于是 Prufer 序列的个数为 $n^{n-2}$ ,于是对应的生成树个数就是 $n^{n-2} $ 了 (为什么我突然想到了一个 $n^n$ 的算法 【滑稽) ![](https://www.cnblogs.com/images/cnblogs_com/Judge/1264714/o_03C080BF92B6752210DE61CE9B0A22AA.png) ------------------ # 推导过程 有了前置芝士这道题就好做一点了...吧? 我们考虑计算每棵树对应的 Prufer 序列的贡献...我们令 $d_i$ 表示当前这棵树的 Prufer 序列中编号 i 出现的次数,那么每个节点的度数就是 $d_i+1$ (被某状态下的叶子结点连接到的次数和自己连出去的一次) $$ANS=\sum_{\sum d_i=n-2}{(n-2)!\over \prod d_i!} \prod_{i=1}^n {a_i}^{d_i+1}(d_i+1)^m \sum_{j=1}^n (d_j+1)^m$$ 这个式子前面的一坨阶乘就是说:对于一个 Prufer 序列: d ,我们考虑 $(n-2)!$ 排列,然后一个编号 $i$ 出现了 $d_i$ 次,他们的相对顺序对 Prufer 序列是没影响的,要除去贡献 然后对于后面 $d_i+1$ 其实已经解释过了,就是考虑一个点的度数为其在 Prufer 序列中的出现次数加上 1 然后我们就好开始推导了: $$\begin{aligned}ANS=&\sum_{\sum d_i=n-2}{(n-2)!\over \prod d_i!} \prod_{i=1}^n {a_i}^{d_i+1}(d_i+1)^m \sum_{j=1}^n (d_j+1)^m \\ =& (n-2)! \prod_{i=1}^n a_i \sum_{\sum d_i=n-2} \prod_{i=1}^{n} {a_i^{d_i}\over d_i! } (d_i+1)^m \sum_{j=1}^n (d_j+1)^m\\ =& (n-2)! \prod_{i=1}^n a_i \sum_{\sum d_i=n-2} (d_i+1)^m \sum_{i=1}^n {a_i^{d_i} \over d_i!}(d_i+1)^{2m} \prod_{j=1,j \not= i}^{n} {a_j^{d_j}\over d_j! }(d_i+1)^m \end{aligned}$$ 然后我们发现前面的两个表达式可以 $O(n)$ 求,所以主要任务就是求出后面这个东西: $$\sum_{\sum d_i=n-2} (d_i+1)^m \sum_{i=1}^n {a_i^{d_i} \over d_i!}(d_i+1)^{2m} \prod_{j=1,j \not= i}^{n} {a_j^{d_j}\over d_j! }(d_i+1)^m $$ 这个时候我们肯定没办法再直接推了,那么我们引入重头戏:生成函数 我们首先构造两个生成函数: $$\begin{aligned} A(x)&=\sum_{i=0}^\infty{x^i (i+1)^{m}\over i!} \\B(x)&=\sum_{i=1}^\infty{x^i(i+1)^{2m}\over i!}\end{aligned}$$ 然后我们令: $$F(x)=\sum_{i=1}^nA(a_ix) \prod_{j=1,j\not=i}^n B(a_jx)$$ 那么这个生成函数的第 n-2 项就是我们要求的式子的答案辣! $$\begin{aligned}F(x)=&\sum_{i=1}^nB(a_ix) \prod_{j=1,j\not=i}^n A(a_jx) \\=& \sum_{i=1}^n{B(a_ix)\over A(a_ix)} \prod_{j=1}^n A(a_jx) \\=&\sum_{i=1}^n {B(a_ix)\over A(a_ix)} Exp(\sum_{j=1}^n Ln( A(a_jx)) ) \end{aligned}$$ 上面这个 $\prod$ 取 Ln 转 $\sum$ 再 Exp 回来简直是套路? 那么新的问题来了,我们求出 $Ln~ A$ 和 ${B\over A}$ 的系数后,还要将第 k (k=1~n)项乘上 $\sum_{i=1}^{n} a_i^k$ 那么我们考虑如何快速求出这个东西... ----- # 数列的幂和函数 讲道理,如果说只是幂和函数的话,我们可以花样求: >拉格朗日插值 >第二类斯特林数 >伯努利数 但是这里有个数列... 其实数列的话拉格朗日插值照样可以用,只不过这个复杂度嘛... $$k^2,k=n$$ !!!玩个鸡儿... 那么我们只能考虑奇技淫巧了0.0 (和[这道题](https://www.luogu.org/blog/yestoday/solution-p4705)做法蛮像的) 首先我们构造一下生成函数: $$f_j(x)= \sum_{i=1}^k a_j^ix^i $$ 那么我们要求的数列幂和函数就是: $$G(x)=\sum_{j=1}^{n} f_j(x)$$ 其中 $G(x)$ 的第 k 项系数就对应了 $\sum_{i=1}^n a_i^k$ 那么我们把 $f_j(x)$ 的上界 k 变为 $\infty$ ,那么原来的东西可以变成: $$f_j(x)={1\over 1-a_ix}$$ $$G(x)=\sum_{i=1}^n{1\over 1-a_ix}$$ 然后我们考虑 $ln(x)'={x'\over x}$ ,那么我们令: $$g_i(x)=ln(1-a_ix)'={-a_i\over 1-a_ix}$$ 那么我们就发现 : $$g_i(x)=-\sum_{i=1}^k a_j^{i+1}x^i=-a_if_i(x)$$ 于是: $$G(x)=\sum_{i=1}^ng_i(x)x+n$$ 然后我们考虑: $$\begin{aligned}\sum_{i=1}^{n}G_i(x)=&\sum_{i=1}^nln(1-a_ix)' \\ =& ln(\prod_{i=1}^n1-a_ix)' \end{aligned}$$ 那么我们现在的任务就轻松很多了,只需要对于后面的多项式求个积... 然鹅暴力合并是 n^2 的说... 于是考虑一种与多点求值类似的做法(分治)? (复杂度大概是 $n~log~n$ ? 复杂度白痴傻傻算不清楚) 那么我们这样求出的 $G(x)$ 的每一项就是我们需要的东西辣~ ------- 然后愉快的敲代码?(板子抄过来QWQ) # code 100 plus 的代码(然鹅这里并没有到 $100^+$ ) ```cpp //by Judge #include<bits/stdc++.h> #define Rg register #define fp(i,a,b) for(Rg int i=(a),I=(b)+1;i<I;++i) #define fd(i,a,b) for(Rg int i=(a),I=(b)-1;i>I;--i) #define ll long long using namespace std; const int mod=998244353; const int iG=332748118; const int M=1e5+3; typedef int arr[M]; #define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++) char buf[1<<21],*p1=buf,*p2=buf; inline int inc(int x,int y){return (x+y)%mod;} inline int mul(int x,int y){return 1ll*x*y%mod;} inline int dec(int x,int y){return (x-y+mod)%mod;} inline int read(){ int x=0,f=1; char c=getchar(); for(;!isdigit(c);c=getchar()) if(c=='-') f=-1; for(;isdigit(c);c=getchar()) x=x*10+c-'0'; return x*f; } int n,m,res,limit,fac[M],finv[M]; arr siz,D[19],A,B,C,G,sum,r; inline int qpow(Rg int x,Rg int p=mod-2,int s=1){ for(;p;p>>=1,x=mul(x,x)) if(p&1) s=mul(s,x); return s; } inline void init(int n){ int l=-1; for(limit=1;limit<n;limit<<=1)++l; fp(i,0,limit-1) r[i]=(r[i>>1]>>1)|((i&1)<<l); } inline void NTT(int* a,int tp){ fp(i,0,limit-1) if(i<r[i]) swap(a[i],a[r[i]]); for(Rg int mid=1;mid<limit;mid<<=1){ int Gn=qpow(tp?3:iG,(mod-1)/(mid<<1)); for(Rg int j=0,I=mid<<1,x,y;j<limit;j+=I) for(Rg int k=0,g=1;k<mid;++k,g=mul(g,Gn)) x=a[j+k],y=mul(a[j+k+mid],g), a[j+k]=(x+y)%mod,a[j+k+mid]=(x-y+mod)%mod; } if(tp) return; int inv=qpow(limit); fp(i,0,limit-1) a[i]=mul(a[i],inv); } void Inv(int* a,int* b,int n){ static arr C,D; if(n==1) return b[0]=qpow(a[0]),void(); Inv(a,b,n>>1),init(n<<1); fp(i,0,n-1) C[i]=a[i],D[i]=b[i]; fp(i,n,limit-1) C[i]=D[i]=0; NTT(C,1),NTT(D,1); fp(i,0,limit-1) C[i]=mul(C[i],mul(D[i],D[i])); NTT(C,0); fp(i,n,limit-1) b[i]=0; fp(i,0,n-1) b[i]=dec(inc(b[i],b[i]),C[i]); } inline void Direv(int* a,int* b,int n){ fp(i,1,n-1) b[i-1]=mul(a[i],i); b[n-1]=0; } inline void Inter(int* a,int* b,int n){ fp(i,1,n-1) b[i]=mul(a[i-1],qpow(i)); b[0]=0; } inline void Ln(int* a,int* b,int n){ static arr C,D; Inv(a,C,n),Direv(a,D,n),init(n<<1); fp(i,n,limit-1) C[i]=D[i]=0; NTT(C,1),NTT(D,1); fp(i,0,limit-1) C[i]=mul(C[i],D[i]); NTT(C,0),Inter(C,b,n); } inline void Exp(int* a,int* b,int n){ if(n==1) return b[0]=1,void(); static arr B; Exp(a,b,n>>1),Ln(b,B,n),B[0]=dec(a[0]+1,B[0]); fp(i,1,n-1) B[i]=dec(a[i],B[i]); init(n<<1); NTT(B,1),NTT(b,1); fp(i,0,limit-1) b[i]=mul(b[i],B[i]); NTT(b,0); fp(i,n,limit-1) b[i]=B[i]=0; } void solv(int l,int r,int d){ if(l==r) return D[d][0]=1,D[d][1]=mod-siz[l],void(); int mid=(l+r)>>1; solv(l,mid,d),solv(mid+1,r,d+1); init(r-l+2); fp(i,mid-l+2,limit-1) D[d][i]=0; fp(i,r-mid+1,limit-1) D[d+1][i]=0; NTT(D[d],1),NTT(D[d+1],1); fp(i,0,limit-1) D[d][i]=mul(D[d][i],D[d+1][i]); NTT(D[d],0); } int main(){ n=read(),m=read(); if(n==1) return !puts("1"); //////////////////// prep fac and ifac fac[0]=finv[0]=finv[1]=1; fp(i,1,n) fac[i]=mul(fac[i-1],i); fp(i,2,n) finv[i]=mul(mod-mod/i,finv[mod%i]); fp(i,2,n) finv[i]=mul(finv[i-1],finv[i]); /////////////////// prep k=1~n, ∑ai^k fp(i,1,n) siz[i]=read(); solv(1,n,0); fp(i,0,n) G[i]=D[0][i]; int len; for(len=1;len<=n;len<<=1); Ln(G,sum,len); fp(i,1,n) sum[i]=mod-mul(sum[i],i); sum[0]=n; /////////////////// calc ANS fp(i,0,n-1) A[i]=mul(qpow(i+1,m),finv[i]), B[i]=mul(qpow(i+1,m<<1),finv[i]); Ln(A,G,len),Inv(A,C,len),init(len<<1),NTT(C,1),NTT(B,1); fp(i,0,(len<<1)-1) B[i]=mul(B[i],C[i]); NTT(B,0); memset(A,0,sizeof A),memset(C,0,sizeof C); fp(i,0,n-1) B[i]=mul(B[i],sum[i]),A[i]=mul(G[i],sum[i]); //带入求值 Exp(A,C,len),init(len<<1); fp(i,n,(len<<1)-1) B[i]=0; // A 要 Exp 回来 NTT(B,1),NTT(C,1); fp(i,0,limit-1) B[i]=mul(B[i],C[i]); NTT(B,0); res=mul(B[n-2],fac[n-2]); fp(i,1,n) res=mul(res,siz[i]); //乘完之后每一项乘入 res return !printf("%d\n",res); } ``` # UPD AT: 2019.5.9