题解 CF981H 【K Paths】

da32s1da

2019-04-16 18:32:05

Solution

一道不错的分治ntt题目 考虑一个合法的选择方案,由于被覆盖的边要么被覆盖1次,要么被覆盖$\mathrm{K}$次,所以被覆盖$\mathrm{K}$次的边一定**构成一条路径**,且其他的边都在这条路径**端点的子树**中, 且每个子树只能选$0$个或$1$个。 我们给这个无根树钦定一个根,然后枚举路径的两个端点$u,v$。这里我们只考虑点$u$的情况。 很显然,$u$的子树$p$可以不从中选端点,也可以选一个端点,而选的方案数是$size[p]$,所以考虑这样一个式子 $$\prod_{p}(1+size[p]x)$$ 这个式子的$k$次方项的系数$a_k$就是从$u$的所有子树中选$k$个端点的方案数。 注意到我们选择的路径是**可重且有先后顺序**的,所以我们还可以从子树中选择$k-1,k-2,\cdots,1,0$个端点,且剩下的只能选$u$。 所以对于点$u$来说,合法的方案数是$f[u]=\sum_{i=0}^kA_k^ia_i$,这里的$A_k^i=\frac{k!}{(k-i)!}$,即排列数。 所以答案是 $$\sum_{u=1}^n\sum_{v=u+1}^nf[u]f[v]=\frac{(\sum_{u=1}^nf[u])^2-(\sum_{u=1}^nf^2[u])}{2}$$ 但这显然是错的(逃 因为点$u$如果是点$v$的祖先,就会挂掉。换句话说,$f[u]$是从$u$及其子树中选$k$个端点的方案,那么路径的另外一个端点要保证在$u$的子树外。将$u$是$v$祖先的这部分答案减掉。 现在考虑$u$和其子孙$v$的答案。假如$v$最接近$u$的祖先是$p$,那么答案是 $$f[v]\sum_{i=0}^k A_k^i[x^i]\big(\frac{1+(n-size[u])x}{1+size[p]x}\prod_{q}(1+size[q]x)\big)$$ 可以将$p$的子孙$v$一起计算,略微变形,变成 $$\sum_{i=0}^kA_k^i [x^i]\big(\sum_vf[v](1+(n-size[u])x)\prod_{q\neq p}(1+size[q]x)\big)$$ 然后就可以大力分治ntt做了。 形式化的,我们令$F(x)=\prod_{i=1}^n(1+a_ix)$,$G(x)=\sum_{i=1}^nb_i\prod_{j=1,j\neq i}^n(1+a_jx)$ 那么我们处理完$[\mathrm{l,mid}],[\mathrm{mid+1,r}]$的答案后,令$F=F_l*F_{mid+1},G=F_l * G_{mid+1}+G_l * F_{mid+1}$即可。(此处$\ *\ $代表卷积) 复杂度分析:由于每一个点只能使其父亲的多项式次数加$1$,所以复杂度是$O(n\log^2n)$ ```cpp #include<cstdio> #include<vector> using namespace std; const int mod=998244353; const int N=262150; typedef pair<int,int> pr; int n,K,x,y,Ans; int siz[N],f[N],fs[N],fac[N],inv[N]; int rnk[N],f1[N],f2[N],g1[N],g2[N],F[N],G[N]; vector<int>vp1[N],vp2[N]; vector<pr>vec[N]; int to[N],las[N],fir[N],cnt; inline void add_edge(int u,int v){ to[++cnt]=v;las[cnt]=fir[u];fir[u]=cnt; to[++cnt]=u;las[cnt]=fir[v];fir[v]=cnt; } int ksm(int u,int v){ int res=1; for(;v;v>>=1,u=1ll*u*u%mod) if(v&1)res=1ll*res*u%mod; return res; } void pre(int u){//预处理阶乘和逆元 fac[0]=fac[1]=inv[0]=inv[1]=1; for(int i=2;i<=u;i++)fac[i]=1ll*fac[i-1]*i%mod; inv[u]=ksm(fac[u],mod-2); for(int i=u-1;i>=2;i--)inv[i]=1ll*inv[i+1]*(i+1)%mod; } inline int _(int u){return u<mod?u:u-mod;} inline int __(int u){return u<0?u+mod:u;} void ntt(int *t,int len,int opt){ int g=3,g_=ksm(g,mod-2); for(int i=0;i<len;i++)if(i<rnk[i])swap(t[i],t[rnk[i]]); for(int i=1;i<len;i<<=1){ int wn=ksm(~opt?g:g_,(mod-1)/(i<<1)); for(int j=0,J=i<<1;j<len;j+=J){ int w=1; for(int k=j;k<i+j;k++,w=1ll*w*wn%mod){ int r=1ll*t[i+k]*w%mod; t[i+k]=__(t[k]-r); t[k]=_(t[k]+r); } } } if(~opt)return; int ny=ksm(len,mod-2); for(int i=0;i<len;i++)t[i]=1ll*t[i]*ny%mod; } void solve(int l,int r,int c){ if(l==r){ vp1[l].resize(2); vp1[l][0]=1; vp1[l][1]=vec[c][l].second; vp2[l].resize(2); vp2[l][0]=_(f[vec[c][l].first]+fs[vec[c][l].first]); vp2[l][1]=1ll*(f[vec[c][l].first]+fs[vec[c][l].first])*(n-siz[c])%mod; //注意乘\sum f(v) return; } int m=l+r>>1; solve(l,m,c);solve(m+1,r,c);m++; for(int i=0;i<vp1[l].size();i++)f1[i]=vp1[l][i]; for(int i=0;i<vp1[m].size();i++)f2[i]=vp1[m][i]; for(int i=0;i<vp2[l].size();i++)g1[i]=vp2[l][i]; for(int i=0;i<vp2[m].size();i++)g2[i]=vp2[m][i]; int len=1,_2=-1,Len=vp1[l].size()+vp1[m].size()-1; while(len<Len)len<<=1,_2++; for(int i=0;i<len;i++)rnk[i]=(rnk[i>>1]>>1)|((i&1)<<_2); for(int i=vp1[l].size();i<len;i++)f1[i]=0; for(int i=vp1[m].size();i<len;i++)f2[i]=0; for(int i=vp2[l].size();i<len;i++)g1[i]=0; for(int i=vp2[m].size();i<len;i++)g2[i]=0; ntt(f1,len,1);ntt(f2,len,1);ntt(g1,len,1);ntt(g2,len,1); for(int i=0;i<len;i++)F[i]=1ll*f1[i]*f2[i]%mod;// F=F_l*F_{mid+1} for(int i=0;i<len;i++)G[i]=_(1ll*f1[i]*g2[i]%mod+1ll*f2[i]*g1[i]%mod);// G=F_l*G_{mid+1}+G_l*F_{mid+1} ntt(F,len,-1);ntt(G,len,-1); vp1[l].resize(Len);vp2[l].resize(Len); for(int i=0;i<Len;i++)vp1[l][i]=F[i]; for(int i=0;i<Len;i++)vp2[l][i]=G[i]; } void dfs(int u,int v){ siz[u]=1; for(int i=fir[u];i;i=las[i]) if(to[i]!=v){ dfs(to[i],u);siz[u]+=siz[to[i]]; vec[u].push_back(pr(to[i],siz[to[i]])); fs[u]=_(fs[u]+_(fs[to[i]]+f[to[i]])); } if(!vec[u].size())vec[u].push_back(pr(0,0));//细节 solve(0,vec[u].size()-1,u); for(int i=0;i<=min(K,(int)vp1[0].size()-1);i++) f[u]=_(f[u]+1ll*vp1[0][i]*fac[K]%mod*inv[K-i]%mod); x=_(x+f[u]);y=_(y+1ll*f[u]*f[u]%mod); Ans=__(Ans-1ll*f[u]*fs[u]%mod);//减掉祖孙关系的答案 for(int i=0;i<=min(K,(int)vp2[0].size()-1);i++) Ans=_(Ans+1ll*vp2[0][i]*fac[K]%mod*inv[K-i]%mod); } int main(){ scanf("%d%d",&n,&K); if(K==1){printf("%d\n",1ll*n*(n-1)%mod*ksm(2,mod-2)%mod);return 0;}//特判K=1 for(int i=1;i<n;i++){ scanf("%d%d",&x,&y); add_edge(x,y); } pre(K);x=y=0; dfs(1,1); printf("%d\n",_(Ans+1ll*__(1ll*x*x%mod-y)*ksm(2,mod-2)%mod)); } ```