题解 CF434E 【Furukawa Nagisa's Tree】

Styx

2018-08-27 10:14:31

Solution

在这之前先叨叨一句:我是因为这题是绿题才进来的QAQ ![](https://cdn.luogu.com.cn/upload/pic/30696.png) 本来以为是到难得水的E题,结果推了一波结论发现事情并不简单…… 我的天,点分治是绿题qwq???果然洛谷大佬辈出啊qwq 先简述一下题意: 给出一棵树,树上各点都有点权。 定义$S(s,t)$ 为树上$s\rightarrow t$的简单路径的序列。 设$s$位置的点**点权**为$z_0$,他到$t$的路径上的下一个点点权为$z_1$,以此类推……再设一下这条路径的长度为$l$。 然后定义$G(S(s,t))={z_0}*k^{0}+{z_1}*k^{1}+{z_2}*k^{2}+...+{z_{l-1}}*k^{l-1}$。 如果一条路径满足$G(S(s,t))\equiv x(mod y)$ 那么这条边就是好边,反之如果不满足,这就是条坏边。 定义好三元组$(x,y,z)$为$(x,y),(y,z),(x,z)$**都是**好边或坏边的三元组。 求这棵树上有多少对好三元组?(三元组中的点可以重复) 保证y是质数 这题非常的绕…… 所以,首先我们先不考虑好边坏边怎么在时限内求,先把怎么判断三元组弄清楚 思考一下,很显然针对三元组来说,只会有这八种情况(0是坏边,1是好边) ![](https://cdn.luogu.com.cn/upload/pic/30701.png) 首先先考虑直接判断好三元组,你会发现,我们必须要知道三条边的信息才能判断,反正我比较菜这个东西我是想不到$O(nlogn)$的做法,所以要在思考一下别的方法,我们再看坏三元组的特征 坏三元组和好三元组的区别显然是有点连着一条好边和一条坏边,如下图红的点 ![](https://cdn.luogu.com.cn/upload/pic/30706.png) 那么统计一下,假设好的入边为$in_1$,坏的为$in_0$,同样,假设好的出边为$out_1$,坏的为$out_0$,在一番数数之后,我们得出图中各类点连得边的数量为 >in0 out1 2 in1 out0 2 in1 in0 4 out1 out0 4 因为有两个点,所以每个点的贡献为1/2 即一个点会产生$in_0*in_1*2+out0*out1*2+in0*out1+in1*out0$个坏三元组 因为一个坏三元组里有两个这样的点,所以答案要除二 这样我们就把坏三元组的个数弄到了一个点上,这里的复杂度降到了$O(n)$ 显然根据容斥的思想减掉坏三元组就是好三元组的数量 好的,接着开始接过之前的话题——怎么在时限内求出一个点的$in_1$、$in_0$、$out_1$、$out_0$ 因为三元组中的点可以重复,所以每个点都有$n$条出边,$n$条入边。 所以$in_0=n-in_1,out_0=n-out_1$ 至于$in_1$、$out_1$ 这个是可以用点分治搞得 好的先进入数论环节 首先在点分的时候推两个距离,$vu=(prevu*k+a[now])$表示从下到上的距离1,$vd=(prevd+a[now]*k^{deep})$表示从上到下的距离2 ![](https://cdn.luogu.com.cn/upload/pic/30708.png) 那么显然当一条边是好边的时候存在: $vu+vd*k^{len[vu]} \equiv x(mod y)$ 感觉$len[vu]$这个东西比较尴尬,我们移下项 $vd \equiv (x-vu)/k^{len[vu]} (mod y)$ 其中右项就可以边dfs边处理了,相当于减一下再乘个逆元,逆元表可以和幂次表一起打,反正y一定是质数的话一趟费马小定理就可以了。 那么只要存在$vd \equiv (x-vu)/k^{len[vu]} (mod y)$ 就可以组成一条好边 我们可以把所有得到的vd和右边这个值全部sort一下,然后尺取求一下相等的对数,反正就基本能点分出来$in_1$、$out_1$了…… 之后就按上面的方法计算答案了,总三元组个数为$n^3$,答案就是$n^3-ans/2$ 代码如下: ```cpp #include<map> #include<set> #include<queue> #include<cmath> #include<cstdio> #include<string> #include<vector> #include<cstring> #include<iostream> #include<algorithm> using namespace std; long long cnt,vis[100010],fa[100010],sz[100010],in0[100010],in1[100010],out0[100010],out1[100010],tm[100010],inv[100010],a[100010],x,y,k,n,ans; vector<int> g[100010]; struct point { long long val,id; } tmp1[100010],tmp2[100010]; int cmp(point a,point b) { return a.val<b.val; } long long kasumi(long long a,long long b) { long long ans=1; while(b) { if(b&1) { ans=ans*a%y; } a=a*a%y; b>>=1; } return ans; } int get_sz(int now,int f) { sz[now]=1; fa[now]=f; for(int i=0; i<g[now].size(); i++) { if(g[now][i]==f||vis[g[now][i]]) continue; get_sz(g[now][i],now); sz[now]+=sz[g[now][i]]; } } int get_zx(int now,int f) { if(sz[now]==1) return now; int maxson=-1,son; for(int i=0; i<g[now].size(); i++) { if(g[now][i]==f||vis[g[now][i]]) continue; if(maxson<sz[g[now][i]]) { maxson=sz[g[now][i]]; son=g[now][i]; } } int zx=get_zx(son,now); while(sz[zx]<2*(sz[now]-sz[zx])) zx=fa[zx]; return zx; } int get_val(int now,int f,int dep,long long vu,long long vd) { vu=(vu*k+a[now])%y; if(dep) vd=(vd+a[now]*tm[dep-1])%y; tmp1[++cnt].id=now; tmp1[cnt].val=(x-vu+y)*inv[dep+1]%y; tmp2[cnt].id=now; tmp2[cnt].val=vd; for(int i=0; i<g[now].size(); i++) { if(g[now][i]==f||vis[g[now][i]]) continue; get_val(g[now][i],now,dep+1,vu,vd); } } int calc(int now,int kd,int dep,long long vu,long long vd) { int posa,posb,tot; cnt=0; vu=(vu*k+a[now])%y; if(dep) vd=(vd+a[now]*tm[dep-1])%y; tmp1[++cnt].id=now; tmp1[cnt].val=(x-vu+y)*inv[dep+1]%y; tmp2[cnt].id=now; tmp2[cnt].val=vd; for(int i=0; i<g[now].size(); i++) { if(!vis[g[now][i]]) { get_val(g[now][i],now,dep+1,vu,vd); } } sort(tmp1+1,tmp1+cnt+1,cmp); sort(tmp2+1,tmp2+cnt+1,cmp); tot=0; for(posa=posb=1; posa<=cnt; posa++) { while(posb<=cnt&&tmp2[posb].val<=tmp1[posa].val) { if(posb==1||tmp2[posb].val!=tmp2[posb-1].val) tot=0; tot++; posb++; } if(posb!=1&&tmp2[posb-1].val==tmp1[posa].val) out1[tmp1[posa].id]+=tot*kd; } tot=0; for(posa=posb=1; posb<=cnt; posb++) { while(posa<=cnt&&tmp1[posa].val<=tmp2[posb].val) { if(posa==1||tmp1[posa].val!=tmp1[posa-1].val) tot=0; tot++; posa++; } if(posa!=1&&tmp1[posa-1].val==tmp2[posb].val) in1[tmp2[posb].id]+=tot*kd; } } int solve(int now) { vis[now]=1; calc(now,1,0,0,0); for(int i=0; i<g[now].size(); i++) { if(vis[g[now][i]]) continue; calc(g[now][i],-1,1,a[now],0); } for(int i=0 ;i<g[now].size(); i++) { if(vis[g[now][i]]) continue; get_sz(g[now][i],0); int zx=get_zx(g[now][i],0); solve(zx); } } int main() { scanf("%lld%lld%lld%lld",&n,&y,&k,&x); for(int i=1; i<=n; i++) { scanf("%lld",&a[i]); } tm[0]=1; inv[0]=1; long long invk=kasumi(k,y-2); for(int i=1; i<=n; i++) { tm[i]=tm[i-1]*k%y; inv[i]=inv[i-1]*invk%y; } int from,to; for(int i=1; i<n; i++) { scanf("%d%d",&from,&to); g[from].push_back(to); g[to].push_back(from); } get_sz(1,0); int zx=get_zx(1,0); solve(zx); for(int i=1; i<=n; i++) { in0[i]=n-in1[i]; out0[i]=n-out1[i]; ans+=2*in0[i]*in1[i]+2*out0[i]*out1[i]+in1[i]*out0[i]+in0[i]*out1[i]; } printf("%lld\n",1ll*n*n*n-ans/2); } ```