Irelia

Irelia

逝去的是刀锋,不变的是意志

题解 P2633 【Count on a tree】

posted on 2019-08-31 15:31:47 | under 题解 |

值域主席树能做的事情,可持久化Trie都能做到

所以,我们可以用可持久化Trie来做这道题

解题思路

  • 对于每一个树上的点,建一个版本的Trie,来存储从根到该点路径上的所有数。可见,在每一个节点上,我们只需要以父节点的Trie作为原版本,再插入该节点自身的权值即可

  • 为了得到 $a$ 到 $b$ 路径上所有数的信息,我们可以用树上差分 $rt_a + rt_b - rt_{lca} - rt_{fa(lca)}$ 来求出,其中 $lca$ 是 $a$ 、 $b$ 的最近公共祖先, $fa(lca)$ 是 $lca$ 的父节点。直接在这四个Trie上同步往下走,对 $size$ 做和差即可,基本框架和Trie求全局第 $k$ 小一样,关于Trie求区间第 $k$ 小可以看这篇题解

程序实现

树剖+Trie

// luogu-judger-enable-o2
#include <iostream>

// Standard-library names used unqualified throughout the file.
// cerr is referenced only by the debugging helper Print().
using std :: cin;
using std :: cout;
using std :: swap;
using std :: cerr;
// Unsigned type used for the values stored in the trie and for the bit
// masks (1U << ...) that select one bit of a value.
typedef unsigned int U;

// Size of every per-vertex array (a, fa, dep, ..., head, rt); the edge pool
// is declared with twice this many slots.
const int MAXN = 1e5 + 5;

// One directed arc stored in a singly linked adjacency list.
struct Edge {
    int to;     // vertex this arc points at
    Edge *nxt;  // next arc in the same list, or null at the end of the list

    // Leaves both members unset; needed so that the pool array below can be
    // declared without initializers.
    Edge() {}

    // Builds an arc to `dest` that is followed by `next` in its list.
    Edge(int dest, Edge *next) : to(dest), nxt(next) {}
};

// Statically allocated storage handed out slot by slot by NewEdge().
Edge epool[MAXN << 1];

// head[u] is the first arc of u's adjacency list (null while u has no arcs);
// AddEdge() pushes new arcs onto the front of this list.
Edge *head[MAXN];

// Takes the next unused slot of epool, fills it with (to, nxt) and returns
// its address. Slots are never recycled and no capacity check is made.
Edge *NewEdge(int to, Edge *nxt) {
    static int used = 0;  // number of pool slots handed out so far
    Edge *slot = epool + used;
    ++used;
    slot->to = to;
    slot->nxt = nxt;
    return slot;
}

void AddEdge(int u, int v) {
    head[u] = new Edge(v, head[u]);
}

// n and m are the first two numbers read by Init(): n bounds the vertex
// loops (vertices are 1..n) and m is the number of queries Work() answers.
int n, m;
// a[v] is the value read for vertex v; Dfs3() inserts it into v's trie.
int a[MAXN];
// Per-vertex data of the heavy-light decomposition:
//   fa  - parent (0 for the root), dep - depth, siz - subtree size,
//   son - child with the largest subtree (0 if none)      [set by Dfs1]
//   top - topmost vertex of the chain containing the vertex,
//   idx - visiting order number assigned from idcnt        [set by Dfs2]
int fa[MAXN], dep[MAXN], son[MAXN], siz[MAXN], top[MAXN], idx[MAXN];
// Counter incremented by Dfs2() each time it numbers a vertex.
int idcnt;

void Dfs1(int u, int las, int depth) {
    fa[u] = las;
    dep[u] = depth;
    siz[u] = 1;
    for (Edge *e = head[u]; e; e = e->nxt) {
        int v = e->to;
        if (v == fa[u]) continue;
        Dfs1(v, u, depth + 1);
        siz[u] += siz[v];
        if (siz[v] > siz[son[u]]) son[u] = v;
    }
}

void Dfs2(int u, int topf) {
    top[u] = topf;
    idx[u] = ++idcnt;
    if (!son[u]) return;
    Dfs2(son[u], topf);
    for (Edge *e = head[u]; e; e = e->nxt) {
        int v = e->to;
        if (v == fa[u] || v == son[u]) continue;
        Dfs2(v, v);
    }
}

// Returns the lowest common ancestor of x and y using the chain data built
// by Dfs1()/Dfs2(). While the two vertices sit on different chains, the one
// whose chain top is deeper jumps to the parent of that top; once both are
// on the same chain the shallower vertex is the answer.
int Lca(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] > dep[top[y]]) {
            x = fa[top[x]];
        } else {
            y = fa[top[y]];
        }
    }
    return dep[x] > dep[y] ? y : x;
}

// Node of the persistent binary trie.
struct Node {
    int siz;      // how many inserted values pass through this node
    Node *ch[2];  // ch[0] / ch[1]: subtree for the next bit being 0 / 1 (null if absent)
};

// Statically allocated storage for trie nodes, handed out by NewNode().
Node npool[10000000];

// Recomputes now->siz as the sum of the sizes of the children that exist;
// a missing child counts as 0.
void Update(Node *now) {
    int total = 0;
    for (int side = 0; side < 2; ++side) {
        const Node *kid = now->ch[side];
        if (kid) total += kid->siz;
    }
    now->siz = total;
}

// Hands out the next unused slot of npool as an empty node (size 0, no
// children). Slots are never recycled and no capacity check is made.
Node *NewNode() {
    static int used = 0;  // number of pool slots handed out so far
    Node *fresh = &npool[used];
    ++used;
    fresh->ch[0] = NULL;
    fresh->ch[1] = NULL;
    fresh->siz = 0;
    return fresh;
}

// Returns a newly allocated node with the same size and the same child
// pointers as *now (the children themselves are shared, not duplicated).
Node *Copy(Node *now) {
    Node *dup = NewNode();
    dup->siz = now->siz;
    dup->ch[0] = now->ch[0];
    dup->ch[1] = now->ch[1];
    return dup;
}

// rt[v] is the root of the trie version built for vertex v by Dfs3();
// rt[0] is the empty version created in Init(), which serves as the
// parent version of the tree root (whose fa is 0).
Node *rt[MAXN];

// Persistent insertion of num into the trie referenced by `now`.
// Every node on the insertion path is replaced by a fresh node (a copy of
// the old one, or an empty node where none existed), so the nodes of the
// previous version are left untouched; `now` is redirected to the new node.
//   now  - in/out: root of the (sub)trie, updated to the new version
//   num  - value being inserted
//   base - number of low bits of num still to be consumed; bit (base - 1)
//          selects the child at this level, and base == 0 marks a leaf
void Insert(Node *&now, U num, int base) {
    now = now ? Copy(now) : NewNode();
    if (base != 0) {
        const int bit = static_cast<int>((num >> (base - 1)) & 1U);
        Insert(now->ch[bit], num, base - 1);
        Update(now);
    } else {
        // Leaf: one more occurrence of this exact value.
        ++now->siz;
    }
}

// Finds the k-th smallest value (k counted from 1) of the multiset
//     now1 + now2 - now3 - now4
// by walking the four tries in lock step from the most significant bit.
// At each level the number of values in the 0-branch is the signed sum of
// the four 0-child sizes: if k fits, go to the 0-branch; otherwise skip
// those values, set the current bit in the result and go to the 1-branch.
//   now1..now4 - current nodes in the four tries (any may be null)
//   base       - number of bits still to decide
//   k          - rank looked for inside the current subtrees
//   res        - bits decided so far
// NOTE(review): presumably callers guarantee 1 <= k <= combined size; the
// function itself does not check this.
U QueryKth(Node *now1, Node *now2, Node *now3, Node *now4, int base, int k, U res) {
    Node *cur[4] = {now1, now2, now3, now4};
    const int sign[4] = {1, 1, -1, -1};
    for (; base != 0; --base) {
        int zeros = 0;  // values whose current bit is 0
        for (int i = 0; i < 4; ++i) {
            if (cur[i] && cur[i]->ch[0]) zeros += sign[i] * cur[i]->ch[0]->siz;
        }
        int dir = 0;
        if (k > zeros) {
            dir = 1;
            k -= zeros;
            res += 1U << (base - 1);
        }
        for (int i = 0; i < 4; ++i) {
            if (cur[i]) cur[i] = cur[i]->ch[dir];
        }
    }
    return res;
}

// Debugging aid: writes to cerr, in increasing order and separated by
// spaces, every distinct value stored in the trie rooted at `now`.
//   now  - current node (null subtrees are skipped)
//   base - number of bits still to decide below this node
//   res  - bits decided on the way down to this node
void Print(Node *now, int base, U res) {
    if (!now) return;
    if (base == 0) {
        cerr << res << " ";
        // Stop at the leaf. Falling through would evaluate
        // 1U << (base - 1) with base == 0, i.e. a shift by -1, which is
        // undefined behaviour.
        return;
    }
    Print(now->ch[0], base - 1, res);
    Print(now->ch[1], base - 1, res | (1U << (base - 1)));
}

void Dfs3(int u) {
    rt[u] = rt[fa[u]];
    Insert(rt[u], a[u], 31);
    for (Edge *e = head[u]; e; e = e->nxt) {
        int v = e->to;
        if (v == fa[u]) continue;
        Dfs3(v);
    }
}

// Reads the input tree and prepares all query structures:
//   1. reads n, m, the n vertex values and the n - 1 undirected edges;
//   2. runs both decomposition passes from vertex 1 (parent 0, depth 1);
//   3. creates the empty trie version rt[0] and builds one version per vertex.
void Init() {
    cin >> n >> m;
    for (int v = 1; v <= n; ++v) cin >> a[v];
    for (int remaining = n - 1; remaining > 0; --remaining) {
        int x, y;
        cin >> x >> y;
        // Store the undirected edge as two directed arcs.
        AddEdge(x, y);
        AddEdge(y, x);
    }
    Dfs1(1, 0, 1);
    Dfs2(1, 1);
    rt[0] = NewNode();
    Dfs3(1);
    // Debug dump of every trie version (disabled):
    // for (int i = 0; i <= n; i++) {
    //     cerr << i << " :\n";
    //     Print(rt[i], 31, 0);
    //     cerr << "\n";
    // }
    // cerr << "*";
}

void Work() {
    int ans = 0;
    for (int i = 1; i <= m; i++) {
        int x, y, k;
        cin >> x >> y >> k;
        x ^= ans;
        int lca = Lca(x, y);
        ans = QueryKth(rt[x], rt[y], rt[lca], rt[fa[lca]], 31, k, 0);
        cout << ans << "\n";
    }
}

int main() {
    Init();
    Work();
    return 0;
}