题解 CF809E 【Surprise me!】
Nemlit
2019-09-30 23:46:38
我们要求的柿子是长这样子的:
$$\frac{1}{n * (n - 1)} * \sum_{i = 1}^n\sum_{j = 1}^{n}\phi(a_i*a_j)*dis(i, j)$$
其中$a_i$为一个排列,$dis(i, j)$表示在树上的距离
这种题的套路一般是先拆柿子,但是这道题的式子……
我们要从一个性质下手:
$$\phi(a * b) = \frac{\phi(a) * \phi(b) * gcd(a, b)}{\phi(gcd(a, b))}$$
代入原式得:
$$\frac{1}{n * (n - 1)} * \sum_{i = 1}^n\sum_{j = 1}^{n}\frac{\phi(a_i) * \phi(a_j) * gcd(a_i, a_j)}{\phi(gcd(a_i, a_j))}*dis(i, j)$$
先忽略前面的数,只看后面的$\sum$,枚举$gcd(a_i, a_j)$,得到
$$\sum_{k = 1}^n\frac{k}{\phi(k)}\sum_{i = 1}^n\sum_{j = 1}^{n}\phi(a_i) * \phi(a_j)*dis(i, j)*[gcd(a_i, a_j) == k]$$
然后反演一波,得到:
$$\sum_{k = 1}^n\frac{k}{\phi(k)}\sum_{i = 1}^n\sum_{j = 1}^{n}\phi(a_i) * \phi(a_j)*dis(i, j)*\sum_{(x * k|a[i]) \& (x * k | a[j])}\mu(x)$$
令$T = k * x$,改为枚举$T$(此时$x = \frac{T}{k}$):
$$\sum_{T = 1}^n\sum_{k|T}\frac{k}{\phi(k)}\sum_{i = 1}^n\sum_{j = 1}^{n}\phi(a_i) * \phi(a_j)*dis(i, j)*\mu(\frac{T}{k})*[T\ |\ a_i]*[T\ |\ a_j]$$
交换顺序得:
$$\sum_{T = 1}^n\sum_{k|T}\frac{k}{\phi(k)} * \mu(\frac{T}{k})\sum_{T\ |\ a_i}\sum_{T\ |\ a_j}\phi(a_i) * \phi(a_j)*dis(i, j)$$
我们考虑枚举T,对于后面的柿子,我们可以单独拎出来,对所有满足$T\ |\ a_i$的点建虚树,用树形DP求出后面柿子的答案,前面的柿子可以提前预处理出来
由于每个$T$对应的关键点只有$\lfloor n / T \rfloor$个,而虚树的点数不超过关键点数的两倍,所以所有虚树的总点数是$\sum_{T = 1}^n O(\frac{n}{T}) = O(n\log n)$个(调和级数),复杂度正确。但由于虚树上的DP和普通DP有一定差异,所以我们还需要对后面的柿子继续化简
$$\sum_{T\ |\ a_i}\sum_{T\ |\ a_j}\phi(a_i) * \phi(a_j)*dis(i, j)$$
拆开$dis(i, j)$得:
$$\sum_{T\ |\ a_i}\sum_{T\ |\ a_j}\phi(a_i) * \phi(a_j)*(dep[i] + dep[j] - 2 * dep[lca(i, j)])$$
令$val[i] = \phi(a_i)$,把所有满足$T\ |\ a_i$的点拎出来,假设有x个
$$\sum_{i= 1}^{x}\sum_{j = 1}^xval[i] * val[j]*(dep[i] + dep[j] - 2 * dep[lca(i, j)])$$
$$\sum_{i= 1}^{x}\sum_{j = 1}^xval[i] * val[j]*dep[i] + \sum_{i= 1}^{x}\sum_{j = 1}^xval[i] * val[j] * dep[j] -2 * \sum_{i= 1}^{x}\sum_{j = 1}^xval[i] * val[j] * dep[lca(i, j)]$$
$$2 * \sum_{i= 1}^{x}val[i] *dep[i] \sum_{j = 1}^xval[j] -2 * \sum_{i= 1}^{x}\sum_{j = 1}^xval[i] * val[j] * dep[lca(i, j)]$$
前面的柿子可以预处理出来,后面的柿子只需要我们在虚树上枚举lca,求出$\sum_{i= 1}^{x}\sum_{j = 1}^xval[i] * val[j]*[lca(i, j) == lca]$
这个值其实不难求:对虚树上每个点$u$记录子树内的权值和$f(u)= \sum_{i \in subtree(u)}val[i]$,依次合并各个儿子的子树时,把新子树的权值和与已合并部分的权值和相乘并累加即可
## $Code:$
```
#include<bits/stdc++.h>
using namespace std;
// Shorthand for the `inline` specifier placed in front of every helper function.
#define il inline
// Historically expanded to `register`. That storage-class specifier was removed
// in C++17 (clang rejects it, gcc warns) and never changed program semantics,
// so the macro now expands to nothing; every `re int x` simply declares `int x`.
#define re
// Modulus for all modular arithmetic; inverses are taken as x^(mod - 2).
#define mod 1000000007
// Reads one optionally signed decimal integer from stdin.
//
// Every character before the first digit is skipped; seeing a '-' while
// skipping makes the result negative. Digits are then accumulated until the
// first non-digit (which is consumed).
//
// Returns the parsed value, or 0 when the input ends before any digit appears.
inline int read() {
    int value = 0;
    int sign = 1;
    // getchar() returns int so that EOF can be told apart from real characters;
    // storing it in a char (as before) made the skip loop spin forever at EOF.
    int ch = std::getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    while (ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = std::getchar();
    }
    return value * sign;
}
// rep(i, s, t): counting loop that declares i and runs it over s, s+1, ..., t (inclusive).
// The bound t is re-evaluated on every iteration.
#define rep(i, s, t) for(re int i = s; i <= t; ++ i)
// Next(i, u): walks the chain of edge indices that starts at head[u], following
// e[i].next until the index becomes 0.
#define Next(i, u) for(re int i = head[u]; i; i = e[i].next)
// mem(k, p): memset over the whole array k (sizeof(k) bytes) with byte value p.
// Only meaningful for true arrays, not pointers.
#define mem(k, p) memset(k, p, sizeof(k))
// Capacity used for every fixed-size global array in this file.
#define maxn 400005
// n: node count read in main. m: declared here; no use is visible in this file.
// Go[i]: value read for node i in main; rev[x]: node whose value is x (rev[Go[i]] = i).
// head[u]: index into e[] of the first cell of u's adjacency chain, 0 when empty.
// cnt: running index of the most recently written slot of e[] (add pre-increments it).
int n, m, Go[maxn], head[maxn], cnt, rev[maxn];
// One adjacency-list cell: v is the neighbour, next is the index of the following
// cell in the same chain (0 terminates the chain).
struct edge { int v, next; }e[maxn << 1];
il void add(int u, int v) {
e[++ cnt] = (edge){v, head[u]}, head[u] = cnt;
e[++ cnt] = (edge){u, head[v]}, head[v] = cnt;
}
il int mul(int a, int b) { return 1ll * a * b % mod; }
il int qpow(int a, int b) {
int r = 1;
while(b) {
if(b & 1) r = mul(a, r);
a = mul(a, a), b >>= 1;
}
return r;
}
// prim[]: primes discovered by the sieve in init (1-based).
// tot: integer counter declared with the sieve arrays (presumably the prime count - confirm against init).
// Vis[x]: set to 1 by init once x has been marked composite.
// phi[x], mu[x]: Euler totient and Moebius values filled in by init.
// F[T]: per-T coefficient accumulated by init over the divisors of T.
// G[T]: per-T tree sum written by solve; ans: final accumulator used in main.
int prim[maxn], tot, Vis[maxn], phi[maxn], mu[maxn], F[maxn], ans, G[maxn];
il void init(int n) {
mu[1] = phi[1] = 1;
rep(i, 2, n) {
if(!Vis[i]) prim[++ cnt] = i, mu[i] = -1, phi[i] = i - 1;
rep(j, 1, cnt) {
if(i * prim[j] > n) break;
Vis[i * prim[j]] = 1;
if(i % prim[j] == 0) {
phi[i * prim[j]] = phi[i] * prim[j];
break;
}
mu[i * prim[j]] = -mu[i], phi[i * prim[j]] = phi[i] * phi[prim[j]];
}
}
rep(i, 1, n)
for(re int j = i; j <= n; j += i)
F[j] = (F[j] + mul(mul(i, qpow(phi[i], mod - 2)), mu[j / i])) % mod, F[j] = (F[j] + mod) % mod;
}
int fa[maxn], dep[maxn], Top[maxn], dfn[maxn], col, son[maxn], size[maxn];
il void dfs1(int u, int fr) {
size[u] = 1, fa[u] = fr, dep[u] = dep[fr] + 1;
Next(i, u) {
int v = e[i].v;
if(v == fr) continue;
dfs1(v, u), size[u] += size[v];
if(size[v] > size[son[u]]) son[u] = v;
}
}
il void dfs2(int u, int fr) {
dfn[u] = ++ col, Top[u] = fr;
if(son[u]) dfs2(son[u], fr);
Next(i, u) if(e[i].v != fa[u] && e[i].v != son[u]) dfs2(e[i].v, e[i].v);
}
il int LCA(int u, int v) {
while(Top[u] != Top[v]) dep[Top[u]] > dep[Top[v]] ? u = fa[Top[u]] : v = fa[Top[v]];
return dep[u] > dep[v] ? v : u;
}
// st[1..top]: stack used by insert/build while the virtual tree is assembled.
// a[1..pax]: key nodes selected by solve for the current T; pax is their count.
// tmp: scratch accumulator (dfs_mem adds every g[u] into it).
// vis[u]: 1 while u is a key node of the current virtual tree.
// f[u], val[u], g[u]: per-node DP values written by get_dis and zeroed by dfs_mem.
int st[maxn], top, a[maxn], tmp, pax, vis[maxn], f[maxn], val[maxn], g[maxn];
il bool cmp(int a, int b) { return dfn[a] < dfn[b]; }
il void insert(int x) {
if(top == 1 && x != 1) return (void)(st[++ top] = x);
int lca = LCA(st[top], x);
if(x == lca) return;
while(top > 1 && dep[st[top - 1]] > dep[lca]) {
add(st[top], st[top - 1]), -- top;
}
if(dep[st[top]] > dep[lca]) add(lca, st[top]), -- top;
if(dep[st[top]] < dep[lca]) st[++ top] = lca;
st[++ top] = x;
}
il void build(int n) {
sort(a + 1, a + n + 1, cmp), st[top = 1] = 1;
rep(i, 1, n) insert(a[i]);
while(top > 1) add(st[top - 1], st[top]), -- top;
}
il void get_dis(int u, int fr) {
if(vis[u]) f[u] = mul(phi[Go[u]], dep[u]), val[u] = phi[Go[u]];
int sum = val[u];
Next(i, u) {
int v = e[i].v;
if(v == fr) continue;
get_dis(v, u);
g[u] = (g[u] + mul(val[v], sum)) % mod;
sum = (sum + val[v]) % mod;
f[u] = (f[u] + f[v]) % mod, val[u] = (val[u] + val[v]) % mod;
}
g[u] = mul(g[u], dep[u]);
}
il void dfs_mem(int u, int fr) {
Next(i, u) if(e[i].v != fr) dfs_mem(e[i].v, u);
tmp = (tmp + g[u]) % mod, head[u] = vis[u] = f[u] = val[u] = g[u] = 0;
}
il void solve() {
rep(T, 1, n / 2) {
pax = tmp = cnt = 0;
for(re int i = T; i <= n; i += T) a[++ pax] = rev[i], vis[rev[i]] = 1;
build(pax), get_dis(1, 0);
G[T] = 2ll * mul(f[1], val[1]) % mod;
dfs_mem(1, 0), tmp = mul(2, tmp);
rep(i, 1, pax) tmp = (tmp + mul(dep[a[i]], mul(phi[Go[a[i]]], phi[Go[a[i]]]))) % mod;
G[T] = (G[T] - 2ll * tmp % mod + mod) % mod;
}
}
int main() {
n = read(), init(n);
rep(i, 1, n) Go[i] = read(), rev[Go[i]] = i;
rep(i, 1, n - 1) add(read(), read());
dfs1(1, 0), dfs2(1, 1), mem(head, 0), solve();
rep(i, 1, n) ans = (ans + mul(G[i], F[i])) % mod;
printf("%d", mul((ans + mod) % mod, qpow(mul(n, n - 1), mod - 2)));
return 0;
}
```