题解 P3369 【【模板】普通平衡树(Treap/SBT)】

2017-12-03 19:22:35


7个月前写过一篇题解,今天回来看下结果自己都看不下去,于是就来重写了。

递归版Splay

优点:不用维护父指针!!!刚开始学写非递归Splay的时候被父指针的维护坑了好久!!!

参考了大刘的《训练指南》。

实现

前排警告:前方存在大量结构体+指针

首先我用名为node的结构体保存节点:

const int inf = 0x7fffffff;
struct node *nil; // 哨兵节点,用于防止访问无效内存导致翻车
struct node
{
    node *ch[2];        // ch[0]是左儿子指针,ch[1]是右儿子指针
    int val, cnt, size; // 元素的值、元素个数(处理重复)、该节点构成的子树包含元素的个数
    int cmp(int v)      // 如代码所示,返回要寻找值为v的元素该向左走还是向右走
    {
        if (v == val)
            return -1;
        else
            return v < val ? 0 : 1;
    }
    int cmpkth(int k) // 同上,返回要寻找第k小元素该向左走还是向右走
    {
        if (k <= ch[0]->size)
            return 0;
        else if (k <= ch[0]->size + cnt)
            return -1;
        else
            return 1;
    }
    void pullup() { size = cnt + ch[0]->size + ch[1]->size; }      // 用于插入或删除后重新计算size
    node(int v) : val(v), cnt(1), size(1) { ch[0] = ch[1] = nil; } //  普通的构造函数
} * root;
void init() // 主要用来初始化哨兵节点
{
    nil = new node(0);
    root = nil->ch[0] = nil->ch[1] = nil;
    nil->size = nil->cnt = 0;
}

下面的说明中node既可以表示节点也可以表示

用0/1表示向左/向右,用ch[0]ch[1]表示左儿子/右儿子指针,用cmp(int v)返回往左走/往右走,用异或运算取相反方向,这些都是来自大刘《训练指南》的技巧。因为平衡树中对称的情形太多了,合理运用这些技巧可以压缩代码量。

伸展

所谓递归Splay,其实就是把寻找节点和伸展节点写在了一起,把递归寻找节点展开一层以后塞几行代码调用旋转。如果找不到这个值的节点,就会伸展最后一个访问到的节点。

void rotate(node *&t, int d) //传引用很重要!!
{
    node *k = t->ch[d ^ 1];
    t->ch[d ^ 1] = k->ch[d];
    k->ch[d] = t;
    t->pullup(), k->pullup(); // 注意此时k已经是t的父亲
    t = k;
}
void splay(int v, node *&t) // 在树t中寻找值为v的节点,并伸展成为t的根节点;传引用很重要!!
{
    int d = t->cmp(v);              //下一步该走的方向
    if (d != -1 && t->ch[d] != nil) //如果下一步可以走向一个合法结点
    {
        int d2 = t->ch[d]->cmp(v);               //下两步该走的方向
        if (d2 != -1 && t->ch[d]->ch[d2] != nil) //如果下两步可以走向一个合法结点
        {
            splay(v, t->ch[d]->ch[d2]); //先递归
            if (d == d2)
                rotate(t, d2 ^ 1), rotate(t, d ^ 1); // zig-zig
            else
                rotate(t->ch[d], d2 ^ 1), rotate(t, d ^ 1); //zig-zag
        }
        else
            rotate(t, d ^ 1); // zig
    }
    // else t已经是终点
}
void splaykth(int k, node *&t) // 同上,在树t中寻找第k小的节点,并伸展成为t的根节点;传引用很重要!!
{
    int d = t->cmpkth(k);
    if (d == 1)
        k -= t->ch[0]->size + t->cnt;
    if (d != -1)
    {
        int d2 = t->ch[d]->cmpkth(k);
        int k2 = (d2 == 1) ? k - (t->ch[d]->ch[0]->size + t->ch[d]->cnt) : k;
        if (d2 != -1)
        {
            splaykth(k2, t->ch[d]->ch[d2]);
            if (d == d2)
                rotate(t, d2 ^ 1), rotate(t, d ^ 1);
            else
                rotate(t->ch[d], d2 ^ 1), rotate(t, d ^ 1);
        }
        else
            rotate(t, d ^ 1);
    }
}

虽然写起来是递归,但本质还是自底向上的。百度可以找到真正的自顶向下伸展的方法。


既然Splay可以变来变去,那么很多操作都有“Splay特色”的写法:

求前驱/后继

逛了一圈发现都是暴力插入再求前驱/后继再删除的,下面介绍一个优雅的方法。或许是我原创的吧。

首先伸展X至根,如果X存在,根就会变成X,X的前驱就是左子树最大的值;再伸展左子树的最大值成为左子树的根,就是根节点的左儿子。求后继同理。

如果X不存在呢?可以证明查找节点时最后一个访问的节点必定是前驱或者后继。所以伸展后根就是X的前驱和后继之一。

当根是前驱的时候,前驱就是根,后继就是右子树的最小值;

当根是后继的时候,前驱就是左子树的最大值,后继就是根。

int lower(int v, node *&t = root) // 前驱
{
    splay(v, t);
    if (t->val >= v) // 根是X或是X的后驱
    {
        if (t->ch[0] == nil)
            return -inf;
        splay(inf, t->ch[0]); // 相当于伸展左子树的最大值
        return t->ch[0]->val;
    }
    else
        return t->val;
}
int upper(int v, node *&t = root) // 后驱
{
    splay(v, t);
    if (t->val <= v) // 根是X或是X的前驱
    {
        if (t->ch[1] == nil)
            return inf;
        splay(-inf, t->ch[1]); // 相当于伸展右子树的最小值
        return t->ch[1]->val;
    }
    else
        return t->val;
}

不严谨的证明:

用反证法证明。

想象一下对这棵树中序遍历得到一个有序序列。查找操作和二分查找是一样的,不过每次是以子树的根为分界点,进入左边或右边的序列(不包含该分界点)继续寻找。

由于该序列始终是连续的,若X不存在,最后序列必定会变成空的。考虑在此之前的上一步,若是在一个小于前驱的结点,则下一步必定是往包含前驱的方向缩小序列,故这一步不可能是最后一步(除非前驱不存在)。在一个大于后继的结点同理。

故经过的最后一个结点必定是前驱或后继(若存在)。

求排名

伸展X成为根,求左子树的元素数量+1即可

int getrank(int v, node *&t = root)
{
    splay(v, t);
    return t->ch[0]->size + 1;
}

求K大

int getkth(int k, node *&t = root)
{
    splaykth(k, t);
    return t->val;
}

splaykth(k, root)然后输出root->val即可。

分裂

将树t分为小于等于X和大于X两部分:

  • 若X在树上,先伸展X,这时候树的左子树都是小于X的元素,右子树都是大于X的元素。断开根和右子树的连接即可。

  • 若X不在树上,伸展操作将会把X的前驱或后继伸展至根(证明在下面)。只需判断下根是大于X还是小于X,决定断开根和左子树的连接还是右子树的连接。

node *split(int v, node *&t) // 分裂后,树t都是小于等于X的元素,返回的树都是大于X的元素
{
    if (t == nil)
        return nil;
    splay(v, t);

    node *t1, *t2; // 用于保存分裂后的两棵树
    if (t->val <= v)
        t1 = t, t2 = t->ch[1], t->ch[1] = nil;
    else
        t1 = t->ch[0], t2 = t, t->ch[0] = nil;
    t->pullup();
    t = t1;
    return t2;
}

合并

要合并的两棵树分别为T1和T2,则必须保证树T1的最大值严格小于树T2的最小值。

先伸展T1的最大值节点,这时候T1的根必然没有右子树,将T2接上去即可。

void merge(node *&t1, node *&t2) // 合并后得到的树是t1,t2会变为空树
{
    if (t1 == nil)
        swap(t1, t2);

    splay(inf, t1);
    t1->ch[1] = t2;
    t2 = nil;
    t1->pullup();
}

插入

为什么要把插入和删除放到最后面。因为这两个操作可以通过分裂和合并优雅地实现。

先将树分裂为小于或等于X的树T1和大于X的树T2。

由于T1的根没有右子树,故T1的根就是T1的最大值。检查T1的根是否等于X:若是,说明出现重复,计数加一;否则合并T1和新节点。之后重新合并新的T1和T2。

void insert(int v, node *&t = root)
{
    node *t2 = split(v, t);
    if (t->val == v)
        t->cnt++;
    else
    {
        node *nd = new node(v);
        merge(t, nd);
    }
    merge(t, t2);
}

删除

先将树分裂为小于或等于X的树T1和大于X的树T2。

由于T1的根没有右子树,故T1的根就是T1的最大值。检查T1的根是否为X且计数减一后为0:若是,用T1的左子树代替T1,并删除原T1的根;否则不处理。之后重新合并新的T1和T2。

void erase(int v, node *&t = root)
{
    node *t2 = split(v, t);
    if (t->val == v && --(t->cnt) < 1) // 命中节点,计数先减一,再判断是否要将节点删除
    {
        node *t3 = t->ch[0];
        delete t;
        t = t3;
    }
    merge(t, t2);
}

模板

// https://www.luogu.org/problem/show?pid=3369
// UPD: 2017/12/3
#include <iostream>
using namespace std;
namespace splay // 数据结构用namespace装着是个人习惯
{
const int inf = 0x7fffffff;
struct node *nil; // 哨兵节点,用于防止访问无效内存导致翻车
struct node
{
    node *ch[2];        // ch[0]是左儿子指针,ch[1]是右儿子指针
    int val, cnt, size; // 元素的值、元素个数(处理重复)、该节点构成的子树包含元素的个数
    int cmp(int v)      // 如代码所示,返回要寻找值为v的元素该向左走还是向右走
    {
        if (v == val)
            return -1;
        else
            return v < val ? 0 : 1;
    }
    int cmpkth(int k) // 同上,返回要寻找第k小元素该向左走还是向右走
    {
        if (k <= ch[0]->size)
            return 0;
        else if (k <= ch[0]->size + cnt)
            return -1;
        else
            return 1;
    }
    void pullup() { size = cnt + ch[0]->size + ch[1]->size; }      // 用于插入或删除后重新计算size
    node(int v) : val(v), cnt(1), size(1) { ch[0] = ch[1] = nil; } //  普通的构造函数
} * root;
void init() // 主要用来初始化哨兵节点
{
    nil = new node(0);
    root = nil->ch[0] = nil->ch[1] = nil;
    nil->size = nil->cnt = 0;
}
void rotate(node *&t, int d) //传引用很重要!!
{
    node *k = t->ch[d ^ 1];
    t->ch[d ^ 1] = k->ch[d];
    k->ch[d] = t;
    t->pullup(), k->pullup(); // 注意此时k已经是t的父亲
    t = k;
}
void splay(int v, node *&t) // 在树t中寻找值为v的节点,并伸展成为t的根节点;传引用很重要!!
{
    int d = t->cmp(v);              //下一步该走的方向
    if (d != -1 && t->ch[d] != nil) //如果下一步可以走向一个合法结点
    {
        int d2 = t->ch[d]->cmp(v);               //下两步该走的方向
        if (d2 != -1 && t->ch[d]->ch[d2] != nil) //如果下两步可以走向一个合法结点
        {
            splay(v, t->ch[d]->ch[d2]); //先递归
            if (d == d2)
                rotate(t, d2 ^ 1), rotate(t, d ^ 1); // zig-zig
            else
                rotate(t->ch[d], d2 ^ 1), rotate(t, d ^ 1); //zig-zag
        }
        else
            rotate(t, d ^ 1); // zig
    }
    // else t已经是终点
}
void splaykth(int k, node *&t) // 同上,在树t中寻找第k小的节点,并伸展成为t的根节点;传引用很重要!!
{
    int d = t->cmpkth(k);
    if (d == 1)
        k -= t->ch[0]->size + t->cnt;
    if (d != -1)
    {
        int d2 = t->ch[d]->cmpkth(k);
        int k2 = (d2 == 1) ? k - (t->ch[d]->ch[0]->size + t->ch[d]->cnt) : k;
        if (d2 != -1)
        {
            splaykth(k2, t->ch[d]->ch[d2]);
            if (d == d2)
                rotate(t, d2 ^ 1), rotate(t, d ^ 1);
            else
                rotate(t->ch[d], d2 ^ 1), rotate(t, d ^ 1);
        }
        else
            rotate(t, d ^ 1);
    }
}
// WARNING: split和merge必须要写得格外小心
node *split(int v, node *&t) // 分裂后,树t都是小于等于X的元素,返回的树都是大于X的元素
{
    if (t == nil)
        return nil;
    splay(v, t);

    node *t1, *t2; // 用于保存分裂后的两棵树
    if (t->val <= v)
        t1 = t, t2 = t->ch[1], t->ch[1] = nil;
    else
        t1 = t->ch[0], t2 = t, t->ch[0] = nil;
    t->pullup();
    t = t1;
    return t2;
}
void merge(node *&t1, node *&t2) // 合并后得到的树是t1,t2会变为空树
{
    if (t1 == nil)
        swap(t1, t2);

    splay(inf, t1);
    t1->ch[1] = t2;
    t2 = nil;
    t1->pullup();
}
void insert(int v, node *&t = root)
{
    node *t2 = split(v, t);
    if (t->val == v)
        t->cnt++;
    else
    {
        node *nd = new node(v);
        merge(t, nd);
    }
    merge(t, t2);
}
void erase(int v, node *&t = root)
{
    node *t2 = split(v, t);
    if (t->val == v && --(t->cnt) < 1) // 命中节点,计数先减一,再判断是否要将节点删除
    {
        node *t3 = t->ch[0];
        delete t;
        t = t3;
    }
    merge(t, t2);
}
int getrank(int v, node *&t = root)
{
    splay(v, t);
    return t->ch[0]->size + 1;
}
int getkth(int k, node *&t = root)
{
    splaykth(k, t);
    return t->val;
}
int lower(int v, node *&t = root) // 前驱
{
    splay(v, t);
    if (t->val >= v) // 根是X或是X的后驱
    {
        if (t->ch[0] == nil)
            return -inf;
        splay(inf, t->ch[0]); // 相当于伸展左子树的最大值
        return t->ch[0]->val;
    }
    else
        return t->val;
}
int upper(int v, node *&t = root) // 后驱
{
    splay(v, t);
    if (t->val <= v) // 根是X或是X的前驱
    {
        if (t->ch[1] == nil)
            return inf;
        splay(-inf, t->ch[1]); // 相当于伸展右子树的最小值
        return t->ch[1]->val;
    }
    else
        return t->val;
}
}
int main()
{
    ios::sync_with_stdio(false);
    splay::init();
    int n, opt, x;
    cin >> n;
    while (n--)
    {
        cin >> opt >> x;
        switch (opt)
        {
        case 1:
            splay::insert(x);
            break;
        case 2:
            splay::erase(x);
            break;
        case 3:
            cout << splay::getrank(x) << endl;
            break;
        case 4:
            cout << splay::getkth(x) << endl;
            break;
        case 5:
            cout << splay::lower(x) << endl;
            break;
        case 6:
            cout << splay::upper(x) << endl;
            break;
        }
    }
    return 0;
}