题解 CF809E 【Surprise me!】

Nemlit

2019-09-30 23:46:38

Solution

我们要求的式子长这样：

$$\frac{1}{n(n-1)}\sum_{i=1}^{n}\sum_{j=1}^{n}\phi(a_i a_j)\cdot dis(i,j)$$

其中 $a$ 是一个 $1\sim n$ 的排列，$dis(i,j)$ 表示 $i$、$j$ 两点在树上的距离。

这种题的套路一般是先拆式子，但是这道题的式子……我们要从一个性质下手：

$$\phi(ab)=\frac{\phi(a)\,\phi(b)\,\gcd(a,b)}{\phi(\gcd(a,b))}$$

代入原式得：

$$\frac{1}{n(n-1)}\sum_{i=1}^{n}\sum_{j=1}^{n}\frac{\phi(a_i)\,\phi(a_j)\,\gcd(a_i,a_j)}{\phi(\gcd(a_i,a_j))}\cdot dis(i,j)$$

先忽略前面的系数，只看后面的 $\sum$，枚举 $k=\gcd(a_i,a_j)$，得到：

$$\sum_{k=1}^{n}\frac{k}{\phi(k)}\sum_{i=1}^{n}\sum_{j=1}^{n}\phi(a_i)\,\phi(a_j)\,dis(i,j)\,[\gcd(a_i,a_j)=k]$$

然后反演一波，得到：

$$\sum_{k=1}^{n}\frac{k}{\phi(k)}\sum_{i=1}^{n}\sum_{j=1}^{n}\phi(a_i)\,\phi(a_j)\,dis(i,j)\sum_{x:\ xk\mid a_i,\ xk\mid a_j}\mu(x)$$

枚举 $T=kx$：

$$\sum_{T=1}^{n}\sum_{k\mid T}\frac{k}{\phi(k)}\sum_{i=1}^{n}\sum_{j=1}^{n}\phi(a_i)\,\phi(a_j)\,dis(i,j)\,[T\mid a_i]\,[T\mid a_j]\,\mu\!\left(\frac{T}{k}\right)$$

交换顺序得：

$$\sum_{T=1}^{n}\left(\sum_{k\mid T}\frac{k}{\phi(k)}\,\mu\!\left(\frac{T}{k}\right)\right)\sum_{T\mid a_i}\sum_{T\mid a_j}\phi(a_i)\,\phi(a_j)\,dis(i,j)$$

我们考虑枚举 $T$。对于后面的式子，我们可以单独拎出来：把所有满足 $T\mid a_i$ 的点拿出来建虚树，用树形 DP 求出后面式子的答案；前面的式子可以提前预处理出来。

由于 $a$ 是排列，满足 $T\mid a_i$ 的点恰有 $\lfloor n/T\rfloor$ 个，所有虚树的关键点总数为 $\sum_{T=1}^{n}\lfloor n/T\rfloor=O(n\log n)$（调和级数），而一棵虚树的点数不超过关键点数的两倍，所以虚树的总点数是 $O(n\log n)$ 的，复杂度正确。但由于虚树上的 DP 和普通 DP 有一定差异，所以我们还需要对后面的式子继续化简：

$$\sum_{T\mid a_i}\sum_{T\mid a_j}\phi(a_i)\,\phi(a_j)\,dis(i,j)$$

拆开 $dis(i,j)$ 得：

$$\sum_{T\mid a_i}\sum_{T\mid a_j}\phi(a_i)\,\phi(a_j)\,\bigl(dep[i]+dep[j]-2\,dep[lca(i,j)]\bigr)$$

令 $val[i]=\phi(a_i)$，把所有满足 $T\mid a_i$ 的点拎出来，假设有 $x$ 个（重新编号为 $1\sim x$）：

$$\sum_{i=1}^{x}\sum_{j=1}^{x}val[i]\,val[j]\,\bigl(dep[i]+dep[j]-2\,dep[lca(i,j)]\bigr)$$

$$=\sum_{i=1}^{x}\sum_{j=1}^{x}val[i]\,val[j]\,dep[i]+\sum_{i=1}^{x}\sum_{j=1}^{x}val[i]\,val[j]\,dep[j]-2\sum_{i=1}^{x}\sum_{j=1}^{x}val[i]\,val[j]\,dep[lca(i,j)]$$

$$=2\sum_{i=1}^{x}val[i]\,dep[i]\sum_{j=1}^{x}val[j]-2\sum_{i=1}^{x}\sum_{j=1}^{x}val[i]\,val[j]\,dep[lca(i,j)]$$

前面的式子可以预处理出来；后面的式子只需要我们在虚树上枚举 lca（记为 $u$），求出 $\sum_{i=1}^{x}\sum_{j=1}^{x}val[i]\,val[j]\,[lca(i,j)=u]$。这个值其实不难求，记录 $f(u)=\sum_{i\in subtree(u)}val[i]$（即 $u$ 子树内的 $val$ 之和）即可：依次合并 $u$ 的各个儿子子树时，把「已合并部分的和」乘上「新子树的和」累加起来，再单独补上 $i=j$ 的项。

## $Code:$

```
#include<bits/stdc++.h>
using namespace std;
#define il inline
#define re register
// -----------------------------------------------------------------------------
// CF809E "Surprise me!"
//
// Prints  1/(n(n-1)) * sum_{i,j} phi(a_i * a_j) * dis(i, j)   (mod 1e9+7)
// for a tree on n vertices whose labels a_1..a_n are a permutation of 1..n.
//
// Using phi(ab) = phi(a) phi(b) g / phi(g) with g = gcd(a, b) and a Moebius
// inversion, the double sum becomes  sum_T F[T] * G[T]  where
//   F[T] = sum_{k | T} k / phi(k) * mu(T / k)
//   G[T] = sum over ordered pairs (i, j) with T | a_i and T | a_j of
//          phi(a_i) * phi(a_j) * dis(i, j)
// G[T] is evaluated on the virtual tree spanned by the vertices whose label is
// a multiple of T (plus the root, vertex 1).
//
// Everything is written with explicit std:: qualification and without the
// `il` / `re` helper macros, so this part does not depend on a using-directive
// or on macro definitions appearing before it.
// -----------------------------------------------------------------------------

const int MOD = 1000000007;
const int MAXN = 400005;  // array capacity used throughout

// Reads one decimal integer (optionally preceded by '-') from stdin.
// Stops at EOF instead of spinning forever; returns 0 if no digit was read.
int read() {
    int x = 0, f = 1;
    int c = std::getchar();
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') f = -1;
        c = std::getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = std::getchar();
    }
    return x * f;
}

int n;
int Go[MAXN];   // Go[v]  = label a_v of vertex v
int rev[MAXN];  // rev[x] = vertex carrying label x

// Forward-star adjacency list. The same arrays hold the input tree first and
// are then reused for every virtual tree (solve() resets `cnt`, dfs_mem()
// resets the touched `head` entries).
struct edge {
    int v, next;
};
edge e[MAXN << 1];
int head[MAXN], cnt;

// Adds the undirected edge u - v.
void add(int u, int v) {
    e[++cnt] = edge{v, head[u]};
    head[u] = cnt;
    e[++cnt] = edge{u, head[v]};
    head[v] = cnt;
}

// (a * b) mod MOD; the result keeps the sign of the product.
int mul(int a, int b) { return static_cast<int>(1ll * a * b % MOD); }

// a^b mod MOD by binary exponentiation (b >= 0).
int qpow(int a, int b) {
    int r = 1;
    while (b) {
        if (b & 1) r = mul(a, r);
        a = mul(a, a);
        b >>= 1;
    }
    return r;
}

int prim[MAXN], tot;      // primes found so far, 1-based, and their count
bool composite[MAXN];     // sieve marks
int phi[MAXN], mu[MAXN];  // Euler totient and Moebius function
int F[MAXN], G[MAXN], ans;

// Linear sieve for phi and mu on [1, n], then F[T] for every T in [1, n].
// Uses its own prime counter `tot`, so the edge counter `cnt` is left alone.
void init(int n) {
    mu[1] = phi[1] = 1;
    for (int i = 2; i <= n; ++i) {
        if (!composite[i]) {
            prim[++tot] = i;
            mu[i] = -1;
            phi[i] = i - 1;
        }
        for (int j = 1; j <= tot; ++j) {
            const long long x = 1ll * i * prim[j];
            if (x > n) break;
            composite[x] = true;
            if (i % prim[j] == 0) {
                // prim[j] divides x twice: mu[x] stays 0.
                phi[x] = phi[i] * prim[j];
                break;
            }
            mu[x] = -mu[i];
            phi[x] = phi[i] * (prim[j] - 1);
        }
    }
    for (int k = 1; k <= n; ++k) {
        // k / phi(k) does not depend on the multiple, so invert phi(k) once.
        const int w = mul(k, qpow(phi[k], MOD - 2));
        for (int T = k; T <= n; T += k) {
            // mu may be -1, so the intermediate value can be negative.
            F[T] = ((F[T] + mul(w, mu[T / k])) % MOD + MOD) % MOD;
        }
    }
}

// Heavy-light decomposition of the input tree, used for LCA queries and for
// the DFS order that the virtual-tree construction sorts by.
// The subtree-size array is called `siz` because a global `size` is ambiguous
// with std::size once `using namespace std;` is active in C++17.
int fa[MAXN], dep[MAXN], Top[MAXN], dfn[MAXN], col, son[MAXN], siz[MAXN];

// First pass: parent, depth (root has depth 1), subtree size, heavy child.
void dfs1(int u, int fr) {
    siz[u] = 1;
    fa[u] = fr;
    dep[u] = dep[fr] + 1;
    for (int i = head[u]; i; i = e[i].next) {
        const int v = e[i].v;
        if (v == fr) continue;
        dfs1(v, u);
        siz[u] += siz[v];
        if (siz[v] > siz[son[u]]) son[u] = v;
    }
}

// Second pass: DFS numbering and top of the heavy chain containing u.
void dfs2(int u, int fr) {
    dfn[u] = ++col;
    Top[u] = fr;
    if (son[u]) dfs2(son[u], fr);
    for (int i = head[u]; i; i = e[i].next) {
        const int v = e[i].v;
        if (v != fa[u] && v != son[u]) dfs2(v, v);
    }
}

// Lowest common ancestor of u and v in the input tree.
int LCA(int u, int v) {
    while (Top[u] != Top[v]) {
        if (dep[Top[u]] > dep[Top[v]]) {
            u = fa[Top[u]];
        } else {
            v = fa[Top[v]];
        }
    }
    return dep[u] > dep[v] ? v : u;
}

int st[MAXN], top;  // stack holding the current right-most chain
int a[MAXN], pax;   // key vertices of the current virtual tree and their count
int tmp;            // accumulator shared by dfs_mem() and solve()
int vis[MAXN];      // 1 for key vertices of the current virtual tree
int f[MAXN];        // subtree sum of val * dep over key vertices
int val[MAXN];      // subtree sum of phi(label) over key vertices
int g[MAXN];        // dep[u] * (sum of val_i * val_j over unordered pairs with lca u)

// Orders vertices by DFS number.
bool cmp(int a, int b) { return dfn[a] < dfn[b]; }

// Pushes key vertex x (processed in DFS order) onto the virtual-tree stack,
// emitting edges for the part of the chain that is left behind.
void insert(int x) {
    if (top == 1 && x != 1) {
        st[++top] = x;
        return;
    }
    const int lca = LCA(st[top], x);
    if (x == lca) return;
    while (top > 1 && dep[st[top - 1]] > dep[lca]) {
        add(st[top], st[top - 1]);
        --top;
    }
    if (dep[st[top]] > dep[lca]) {
        add(lca, st[top]);
        --top;
    }
    if (dep[st[top]] < dep[lca]) st[++top] = lca;
    st[++top] = x;
}

// Builds the virtual tree of a[1..n] rooted at vertex 1 into head[] / e[].
void build(int n) {
    std::sort(a + 1, a + n + 1, cmp);
    st[top = 1] = 1;
    for (int i = 1; i <= n; ++i) insert(a[i]);
    while (top > 1) {
        add(st[top - 1], st[top]);
        --top;
    }
}

// Tree DP on the virtual tree; fills f, val and g as described above.
// Relies on f / val / g being zero on entry (dfs_mem() guarantees that).
void get_dis(int u, int fr) {
    if (vis[u]) {
        f[u] = mul(phi[Go[u]], dep[u]);
        val[u] = phi[Go[u]];
    }
    int sum = val[u];  // total val already merged below u, including u itself
    for (int i = head[u]; i; i = e[i].next) {
        const int v = e[i].v;
        if (v == fr) continue;
        get_dis(v, u);
        // Every pair (already merged vertex, vertex in subtree v) meets at u.
        g[u] = (g[u] + mul(val[v], sum)) % MOD;
        sum = (sum + val[v]) % MOD;
        f[u] = (f[u] + f[v]) % MOD;
        val[u] = (val[u] + val[v]) % MOD;
    }
    g[u] = mul(g[u], dep[u]);
}

// Adds every g[u] of the virtual tree to `tmp` and wipes the per-vertex state
// so the next virtual tree starts from clean arrays.
void dfs_mem(int u, int fr) {
    for (int i = head[u]; i; i = e[i].next) {
        if (e[i].v != fr) dfs_mem(e[i].v, u);
    }
    tmp = (tmp + g[u]) % MOD;
    head[u] = vis[u] = f[u] = val[u] = g[u] = 0;
}

// Computes G[T] for T = 1 .. n / 2. A larger T has a single multiple in
// [1, n], so only the pair (i, i) with distance 0 remains and G[T] stays 0.
void solve() {
    for (int T = 1; T <= n / 2; ++T) {
        pax = tmp = cnt = 0;
        for (int x = T; x <= n; x += T) {
            a[++pax] = rev[x];
            vis[rev[x]] = 1;
        }
        build(pax);
        get_dis(1, 0);
        // 2 * (sum val_i * dep_i) * (sum val_j)
        G[T] = static_cast<int>(2ll * mul(f[1], val[1]) % MOD);
        dfs_mem(1, 0);
        tmp = mul(2, tmp);  // unordered -> ordered pairs
        for (int i = 1; i <= pax; ++i) {
            // the i == j terms, whose lca is the vertex itself
            tmp = (tmp + mul(dep[a[i]], mul(phi[Go[a[i]]], phi[Go[a[i]]]))) % MOD;
        }
        G[T] = static_cast<int>((G[T] - 2ll * tmp % MOD + MOD) % MOD);
    }
}

int main() {
    n = read();
    init(n);
    for (int i = 1; i <= n; ++i) {
        Go[i] = read();
        rev[Go[i]] = i;
    }
    for (int i = 1; i <= n - 1; ++i) {
        // Read the endpoints into locals: the evaluation order of function
        // arguments is unspecified, so add(read(), read()) is not portable.
        const int u = read();
        const int v = read();
        add(u, v);
    }
    dfs1(1, 0);
    dfs2(1, 1);
    // The input adjacency is no longer needed; hand head[] to the virtual trees.
    std::memset(head, 0, sizeof(head));
    solve();
    for (int i = 1; i <= n; ++i) ans = (ans + mul(G[i], F[i])) % MOD;
    std::printf("%d", mul((ans + MOD) % MOD, qpow(mul(n, n - 1), MOD - 2)));
    return 0;
}