ouuan 的博客

ouuan 的博客

题解 CF1172E 【Nauuo and ODT】

posted on 2019-06-08 13:56:06 | under 题解 |

对每种颜色分别考虑不含该颜色的简单路径条数。

令当前处理的颜色为 $c$ ,把颜色为 $c$ 的视为白色,不是 $c$ 的视为黑色,那么不含 $c$ 的路径条数就是每个黑联通块的大小的平方和,修改就是当颜色是 $c$ $\leftrightarrow$ 颜色不是 $c$ 时翻转一个点的颜色。所以,问题转化成了黑白两色的树,单点翻转颜色,维护黑联通块大小的平方和。这个转化后的问题可以用很多数据结构做,比如:lxl:top tree 随便维护。这篇题解里使用 Link/cut Tree.

对每个点维护子树大小,儿子大小平方和,在 link/cut 的时候更新答案即可。有一个大家熟知的 trick,就是每个黑点向父亲连边,这样真正的联通块就是 Link/cut Tree 里的联通块删掉根。

具体看图吧,图讲的挺清楚的:

LCT 的部分就是这样,计算答案的时候先初始化一个全是黑点的树,离线处理每个颜色,处理完一个颜色反着改回去。

#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>

using namespace std;

typedef long long ll;

const int N = 400010;

struct Node
{
    int fa, ch[2], siz, sizi;
    ll siz2i;
    ll siz2() { return (ll) siz * siz; }
} t[N];

bool nroot(int x);
void rotate(int x);
void Splay(int x);
void access(int x);
int findroot(int x);
void link(int x);
void cut(int x);
void pushup(int x);

void add(int u, int v);
void dfs(int u);

int head[N], nxt[N << 1], to[N << 1], cnt;
int n, m, c[N], f[N];
ll ans, delta[N];
bool bw[N];
vector<int> mod[N][2];

int main()
{
    int i, j, u, v;
    ll last;

    scanf("%d%d", &n, &m);

    for (i = 1; i <= n; ++i)
    {
        scanf("%d", c + i);
        mod[c[i]][0].push_back(i);
        mod[c[i]][1].push_back(0);
    }

    for (i = 1; i <= n + 1; ++i) t[i].siz = 1;

    for (i = 1; i < n; ++i)
    {
        scanf("%d%d", &u, &v);
        add(u, v);
        add(v, u);
    }

    for (i = 1; i <= m; ++i)
    {
        scanf("%d%d", &u, &v);
        mod[c[u]][0].push_back(u);
        mod[c[u]][1].push_back(i);
        c[u] = v;
        mod[v][0].push_back(u);
        mod[v][1].push_back(i);
    }

    f[1] = n + 1;
    dfs(1);

    for (i = 1; i <= n; ++i) link(i);

    for (i = 1; i <= n; ++i)
    {
        if (!mod[i][0].size())
        {
            delta[0] += (ll)n * n;
            continue;
        }
        if (mod[i][1][0])
        {
            delta[0] += (ll)n * n;
            last = (ll)n * n;
        } else
            last = 0;
        for (j = 0; j < mod[i][0].size(); ++j)
        {
            u = mod[i][0][j];
            if (bw[u] ^= 1)
                cut(u);
            else
                link(u);
            if (j == mod[i][0].size() - 1 || mod[i][1][j + 1] != mod[i][1][j])
            {
                delta[mod[i][1][j]] += ans - last;
                last = ans;
            }
        }
        for (j = mod[i][0].size() - 1; ~j; --j)
        {
            u = mod[i][0][j];
            if (bw[u] ^= 1)
                cut(u);
            else
                link(u);
        }
    }

    ans = (ll) n * n * n;
    for (i = 0; i <= m; ++i)
    {
        ans -= delta[i];
        printf("%I64d ", ans);
    }

    return 0;
}

bool nroot(int x) { return x == t[t[x].fa].ch[0] || x == t[t[x].fa].ch[1]; }

void rotate(int x)
{
    int y = t[x].fa;
    int z = t[y].fa;
    int k = x == t[y].ch[1];
    if (nroot(y)) t[z].ch[y == t[z].ch[1]] = x;
    t[x].fa = z;
    t[y].ch[k] = t[x].ch[k ^ 1];
    t[t[x].ch[k ^ 1]].fa = y;
    t[x].ch[k ^ 1] = y;
    t[y].fa = x;
    pushup(y);
    pushup(x);
}

void Splay(int x)
{
    while (nroot(x))
    {
        int y = t[x].fa;
        int z = t[y].fa;
        if (nroot(y)) (x == t[y].ch[1]) ^ (y == t[z].ch[1]) ? rotate(x) : rotate(y);
        rotate(x);
    }
}

void access(int x)
{
    for (int y = 0; x; x = t[y = x].fa)
    {
        Splay(x);
        t[x].sizi += t[t[x].ch[1]].siz;
        t[x].sizi -= t[y].siz;
        t[x].siz2i += t[t[x].ch[1]].siz2();
        t[x].siz2i -= t[y].siz2();
        t[x].ch[1] = y;
        pushup(x);
    }
}

int findroot(int x)
{
    access(x);
    Splay(x);
    while (t[x].ch[0]) x = t[x].ch[0];
    Splay(x);
    return x;
}

void link(int x)
{
    int y = f[x];
    Splay(x);
    ans -= t[x].siz2i + t[t[x].ch[1]].siz2();
    int z = findroot(y);
    access(x);
    Splay(z);
    ans -= t[t[z].ch[1]].siz2();
    t[x].fa = y;
    Splay(y);
    t[y].sizi += t[x].siz;
    t[y].siz2i += t[x].siz2();
    pushup(y);
    access(x);
    Splay(z);
    ans += t[t[z].ch[1]].siz2();
}

void cut(int x)
{
    int y = f[x];
    access(x);
    ans += t[x].siz2i;
    int z = findroot(y);
    access(x);
    Splay(z);
    ans -= t[t[z].ch[1]].siz2();
    Splay(x);
    t[x].ch[0] = t[t[x].ch[0]].fa = 0;
    pushup(x);
    Splay(z);
    ans += t[t[z].ch[1]].siz2();
}

void pushup(int x)
{
    t[x].siz = t[t[x].ch[0]].siz + t[t[x].ch[1]].siz + t[x].sizi + 1;
}

void add(int u, int v)
{
    nxt[++cnt] = head[u];
    head[u] = cnt;
    to[cnt] = v;
}

void dfs(int u)
{
    int i, v;
    for (i = head[u]; i; i = nxt[i])
    {
        v = to[i];
        if (v != f[u])
        {
            f[v] = u;
            dfs(v);
        }
    }
}