题解 CF981H 【K Paths】
da32s1da
2019-04-16 18:32:05
一道不错的分治ntt题目
考虑一个合法的选择方案,由于被覆盖的边要么被覆盖1次,要么被覆盖$\mathrm{K}$次,所以被覆盖$\mathrm{K}$次的边一定**构成一条路径**,且其他的边都在这条路径**端点的子树**中, 且每个子树只能选$0$个或$1$个。
我们给这个无根树钦定一个根,然后枚举路径的两个端点$u,v$。这里我们只考虑点$u$的情况。
很显然,$u$的子树$p$可以不从中选端点,也可以选一个端点,而选的方案数是$size[p]$,所以考虑这样一个式子
$$\prod_{p}(1+size[p]x)$$
这个式子的$k$次方项的系数$a_k$就是从$u$的所有子树中选$k$个端点的方案数。
注意到我们选择的路径是**可重且有先后顺序**的,所以我们还可以从子树中选择$k-1,k-2,\cdots,1,0$个端点,且剩下的只能选$u$。
所以对于点$u$来说,合法的方案数是$f[u]=\sum_{i=0}^kA_k^ia_i$,这里的$A_k^i=\frac{k!}{(k-i)!}$,即排列数。
所以答案是
$$\sum_{u=1}^n\sum_{v=u+1}^nf[u]f[v]=\frac{(\sum_{u=1}^nf[u])^2-(\sum_{u=1}^nf^2[u])}{2}$$
但这显然是错的(逃
因为点$u$如果是点$v$的祖先,就会挂掉。换句话说,$f[u]$是从$u$及其子树中选$k$个端点的方案,那么路径的另外一个端点要保证在$u$的子树外。将$u$是$v$祖先的这部分答案减掉。
现在考虑$u$和其子孙$v$的答案。假如$v$最接近$u$的祖先是$p$,那么答案是
$$f[v]\sum_{i=0}^k A_k^i[x^i]\big(\frac{1+(n-size[u])x}{1+size[p]x}\prod_{q}(1+size[q]x)\big)$$
可以将$p$的子孙$v$一起计算,略微变形,变成
$$\sum_{i=0}^kA_k^i [x^i]\big(\sum_vf[v](1+(n-size[u])x)\prod_{q\neq p}(1+size[q]x)\big)$$
然后就可以大力分治ntt做了。
形式化的,我们令$F(x)=\prod_{i=1}^n(1+a_ix)$,$G(x)=\sum_{i=1}^nb_i\prod_{j=1,j\neq i}^n(1+a_jx)$
那么我们处理完$[\mathrm{l,mid}],[\mathrm{mid+1,r}]$的答案后,令$F=F_l*F_{mid+1},G=F_l * G_{mid+1}+G_l * F_{mid+1}$即可。(此处$\ *\ $代表卷积)
复杂度分析:由于每一个点只能使其父亲的多项式次数加$1$,所以复杂度是$O(n\log^2n)$
```cpp
#include<cstdio>
#include<vector>
using namespace std;
const int mod=998244353;
const int N=262150;
typedef pair<int,int> pr;
int n,K,x,y,Ans;
int siz[N],f[N],fs[N],fac[N],inv[N];
int rnk[N],f1[N],f2[N],g1[N],g2[N],F[N],G[N];
vector<int>vp1[N],vp2[N];
vector<pr>vec[N];
int to[N],las[N],fir[N],cnt;
inline void add_edge(int u,int v){
to[++cnt]=v;las[cnt]=fir[u];fir[u]=cnt;
to[++cnt]=u;las[cnt]=fir[v];fir[v]=cnt;
}
int ksm(int u,int v){
int res=1;
for(;v;v>>=1,u=1ll*u*u%mod)
if(v&1)res=1ll*res*u%mod;
return res;
}
void pre(int u){//预处理阶乘和逆元
fac[0]=fac[1]=inv[0]=inv[1]=1;
for(int i=2;i<=u;i++)fac[i]=1ll*fac[i-1]*i%mod;
inv[u]=ksm(fac[u],mod-2);
for(int i=u-1;i>=2;i--)inv[i]=1ll*inv[i+1]*(i+1)%mod;
}
inline int _(int u){return u<mod?u:u-mod;}
inline int __(int u){return u<0?u+mod:u;}
void ntt(int *t,int len,int opt){
int g=3,g_=ksm(g,mod-2);
for(int i=0;i<len;i++)if(i<rnk[i])swap(t[i],t[rnk[i]]);
for(int i=1;i<len;i<<=1){
int wn=ksm(~opt?g:g_,(mod-1)/(i<<1));
for(int j=0,J=i<<1;j<len;j+=J){
int w=1;
for(int k=j;k<i+j;k++,w=1ll*w*wn%mod){
int r=1ll*t[i+k]*w%mod;
t[i+k]=__(t[k]-r);
t[k]=_(t[k]+r);
}
}
}
if(~opt)return;
int ny=ksm(len,mod-2);
for(int i=0;i<len;i++)t[i]=1ll*t[i]*ny%mod;
}
void solve(int l,int r,int c){
if(l==r){
vp1[l].resize(2);
vp1[l][0]=1;
vp1[l][1]=vec[c][l].second;
vp2[l].resize(2);
vp2[l][0]=_(f[vec[c][l].first]+fs[vec[c][l].first]);
vp2[l][1]=1ll*(f[vec[c][l].first]+fs[vec[c][l].first])*(n-siz[c])%mod;
//注意乘\sum f(v)
return;
}
int m=l+r>>1;
solve(l,m,c);solve(m+1,r,c);m++;
for(int i=0;i<vp1[l].size();i++)f1[i]=vp1[l][i];
for(int i=0;i<vp1[m].size();i++)f2[i]=vp1[m][i];
for(int i=0;i<vp2[l].size();i++)g1[i]=vp2[l][i];
for(int i=0;i<vp2[m].size();i++)g2[i]=vp2[m][i];
int len=1,_2=-1,Len=vp1[l].size()+vp1[m].size()-1;
while(len<Len)len<<=1,_2++;
for(int i=0;i<len;i++)rnk[i]=(rnk[i>>1]>>1)|((i&1)<<_2);
for(int i=vp1[l].size();i<len;i++)f1[i]=0;
for(int i=vp1[m].size();i<len;i++)f2[i]=0;
for(int i=vp2[l].size();i<len;i++)g1[i]=0;
for(int i=vp2[m].size();i<len;i++)g2[i]=0;
ntt(f1,len,1);ntt(f2,len,1);ntt(g1,len,1);ntt(g2,len,1);
for(int i=0;i<len;i++)F[i]=1ll*f1[i]*f2[i]%mod;// F=F_l*F_{mid+1}
for(int i=0;i<len;i++)G[i]=_(1ll*f1[i]*g2[i]%mod+1ll*f2[i]*g1[i]%mod);// G=F_l*G_{mid+1}+G_l*F_{mid+1}
ntt(F,len,-1);ntt(G,len,-1);
vp1[l].resize(Len);vp2[l].resize(Len);
for(int i=0;i<Len;i++)vp1[l][i]=F[i];
for(int i=0;i<Len;i++)vp2[l][i]=G[i];
}
void dfs(int u,int v){
siz[u]=1;
for(int i=fir[u];i;i=las[i])
if(to[i]!=v){
dfs(to[i],u);siz[u]+=siz[to[i]];
vec[u].push_back(pr(to[i],siz[to[i]]));
fs[u]=_(fs[u]+_(fs[to[i]]+f[to[i]]));
}
if(!vec[u].size())vec[u].push_back(pr(0,0));//细节
solve(0,vec[u].size()-1,u);
for(int i=0;i<=min(K,(int)vp1[0].size()-1);i++)
f[u]=_(f[u]+1ll*vp1[0][i]*fac[K]%mod*inv[K-i]%mod);
x=_(x+f[u]);y=_(y+1ll*f[u]*f[u]%mod);
Ans=__(Ans-1ll*f[u]*fs[u]%mod);//减掉祖孙关系的答案
for(int i=0;i<=min(K,(int)vp2[0].size()-1);i++)
Ans=_(Ans+1ll*vp2[0][i]*fac[K]%mod*inv[K-i]%mod);
}
int main(){
scanf("%d%d",&n,&K);
if(K==1){printf("%d\n",1ll*n*(n-1)%mod*ksm(2,mod-2)%mod);return 0;}//特判K=1
for(int i=1;i<n;i++){
scanf("%d%d",&x,&y);
add_edge(x,y);
}
pre(K);x=y=0;
dfs(1,1);
printf("%d\n",_(Ans+1ll*__(1ll*x*x%mod-y)*ksm(2,mod-2)%mod));
}
```