Nemlit 的博客

Nemlit 的博客

By a konjac

题解 P3233 【[HNOI2014]世界树】

posted on 2019-09-30 22:40:28 | under 题解 |

这道题细节是真的多

看数据范围,这应该是一道虚树DP,我们先来想一下不用虚树怎么做

我们定义 $id[i]$ 为第i个点应该归哪一个议事处管理,且i到 $id[i]$ 的距离为 $dis[i]$

我们做两遍dfs,首先从下到上,用儿子更新父亲,再从上到下,用父亲更新儿子

更新过程十分简单,就类似于重链剖分的思路去更新就好了

然后这里要注意第一个细节,就是必须先用儿子更新父亲,再用父亲更新儿子,因为如果不这么做的话一个父亲可能有多个儿子,所以先更新儿子的话该儿子的'兄弟'不会更新到(画画图就理解了)

暴力DP就是这么做,那如果放在虚树上呢?

对于虚树上的点,我们仍然可以按照上述暴力DP方式来做

那么非虚树上的点呢?

首先虚树是保证了两个相邻的树点在原树中实在一条链上的

所以我们可以里用倍增的思想,求出两个相邻虚树点的分界点,分界点以上归上面的点管理,分界点一下同理

由于我们还要保证编号最小,所以这就是本题第二个坑点:我们需要判断两个虚树点id值的大小

具体实现中我们可以把分界处以下的点染成一种新的颜色,然后递归处理,我们可以保证先处理深度小的点再处理深度大的点

为什么要这么做呢?

因为分界点以下的点已经不归上方关键点管辖,所以我们可以用一种新的颜色覆盖点原来的颜色,表示被一个新点占领了

楼上 $chenkehan$ 大佬给了十分形象的图片

于是就可以愉快的码码码了

献上十分丑陋的代码:

#include<bits/stdc++.h>
using namespace std;
#define il inline
#define re register
il int read() {
    re int x = 0, f = 1; re char c = getchar();
    while(c < '0' || c > '9') { if(c == '-') f = -1; c = getchar();}
    while(c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();
    return x * f;
}
#define rep(i, s, t) for(re int i = s; i <= t; ++ i)
#define drep(i, s, t) for(re int i = t; i >= s; -- i)
#define Next(i, u) for(re int i = head[u]; i; i = e[i].next)
#define _ 300005
int n, m, Q, top[_], Top, st[_], son[_], size[_], dep[_], dfn[_], tot, cnt, head[_];
int id[_], ans[_], f[25][_], dis[_], vis[_], Size[_];
struct node {int a, id;}q[_];
struct edge {int v, next, w;}e[_ << 1];
il bool cmp(node a, node b) {return dfn[a.a] < dfn[b.a];}
il bool cmp1(node a, node b) {return a.id < b.id;}
il void add(int u, int v, int w) {
    e[++ cnt] = (edge){v, head[u], w}, head[u] = cnt;
    e[++ cnt] = (edge){u, head[v], w}, head[v] = cnt;
}
il void dfs1(int u, int fr) {
    f[0][u] = fr, dep[u] = dep[fr] + 1, size[u] = 1, dfn[u] = ++ tot;
    Next(i, u) if(e[i].v != fr) dfs1(e[i].v, u), size[u] += size[e[i].v];
}
il int LCA(int a, int b) {
    if(dep[a] < dep[b]) swap(a, b);
    drep(i, 0, 20) if(dep[a] - (1 << i) >= dep[b]) a = f[i][a];
    drep(i, 0, 20) if(f[i][a] != f[i][b]) a = f[i][a], b = f[i][b];
    return (a == b) ? a : f[0][a];
}
il int Dis(int a, int b) {return dep[a] + dep[b] - dep[LCA(a, b)] * 2;}
il void insert(int x) {
    if(Top == 1 && x != 1) return (void)(st[++ Top] = x);
    int lca = LCA(st[Top], x);  if(lca == x) return;
    while(Top > 1 && dep[st[Top - 1]] > dep[lca]) 
        add(st[Top - 1], st[Top], Dis(st[Top - 1], st[Top])), -- Top;
    if(dep[st[Top]] > dep[lca]) add(st[Top], lca, Dis(st[Top], lca)), -- Top;
    if(dep[st[Top]] < dep[lca]) st[++ Top] = lca;
    st[++ Top] = x;
}
il void dfs_mem(int u, int fr) {
    Next(i, u) if(e[i].v != fr) dfs_mem(e[i].v, u);
    head[u] = vis[u] = id[u] = dis[u] = Size[u] = ans[u] = 0;
}
il int get_fa(int u, int dis) {
    int now = 0;
    drep(i, 0, 20) if(now + (1 << i) <= dis) now += (1 << i), u = f[i][u];
    return u;
}
il void dfs_get(int u, int fr) {
    if(vis[u]) dis[u] = 0, id[u] = u, Size[u] = size[u];
    else dis[u] = 123456789, Size[u] = size[u];
    Next(i, u) {
        int v = e[i].v, w = e[i].w;  if(v == fr) continue;
        dfs_get(v, u), w += dis[v];
        if(dis[u] > w || (dis[u] == w && id[u] > id[v])) dis[u] = w, id[u] = id[v];
    }
}
il void dfs1_get(int u, int fr) {
    Next(i, u) {
        int v = e[i].v, w = dis[u] + e[i].w;  if(v == fr) continue;
        if(w < dis[v] || (w == dis[v] && id[v] > id[u])) dis[v] = w, id[v] = id[u];
        dfs1_get(v, u);
        if(id[u] == id[v]) Size[u] -= size[v];
        else {
            int x = get_fa(v, (dis[v] + dis[u] + e[i].w + (id[u] > id[v]) - 1) / 2 - dis[v]);//Attention: (dep[x] - dep[y])即e[i].w的意义是经过的边的数量!我就是因为这里调了**的
            Size[v] += size[x] - size[v], Size[u] -= size[x];
        }
        ans[id[v]] += Size[v];
    }
    if(u == 1) ans[id[1]] += Size[1];
}
int main() {
    n = read();
    rep(i, 1, n - 1) add(read(), read(), 0);
    Q = read(), dfs1(1, 1), memset(head, 0, sizeof(head));
    rep(i, 1, 20) rep(j, 1, n) f[i][j] = f[i - 1][f[i - 1][j]];
    while(Q --) {
        m = read(), st[Top = 1] = 1, cnt = 0;
        rep(i, 1, m) q[i].a = read(), vis[q[i].a] = 1, q[i].id = i;
        sort(q + 1, q + m + 1, cmp);
        rep(i, 1, m) insert(q[i].a);
        while(Top > 1) add(st[Top - 1], st[Top], Dis(st[Top - 1], st[Top])), -- Top;
        dfs_get(1, 0), dfs1_get(1, 0), sort(q + 1, q + m + 1, cmp1);
        rep(i, 1, m) printf("%d ", ans[q[i].a]);
        puts(""), dfs_mem(1, 0);
    }
    return 0;
}