CF915F Imbalance Value of a Tree

2018-09-29 11:01:58


题意:给定一棵带点权树,定义 $F(x,y)$ 为 $x$ 到 $y$ 简单路径上最大值与最小值的差

   求 $\sum_{j=i}^nF(i,j)$  ( $n<=1e6$ )


易得 $sum(max-min)=sum(max)-sum(min)$

显然可以对于每一个点统计对答案的贡献

分别用并查集求以这个点为最大值和最小值的路径数量

两个过程是类似的,以最大值为例

把点按照点权从小到大、编号从大到小双关键字排序

从前向后扫,当前节点就是最大的

如果与它相连的节点比它的权值小,那么就把两个集合合并,统计答案

求最小值反过来即可

#include<cstdio>
#include<cstring>
#include<cctype>
#include<algorithm>
#define reg register
using namespace std;
typedef long long ll;
const int N=1e6+5;
struct E
{
    int to,nxt;
}edge[N<<1];
struct node
{
    int x,id;
}c[N];
int n,num,head[N],f[N],tot[N];
bool vis[N];
ll ans;
inline int read()
{
    int x=0,w=1;
    char c=getchar();
    while (!isdigit(c)&&c!='-') c=getchar();
    if (c=='-') c=getchar(),w=-1;
    while (isdigit(c))
    {
        x=(x<<1)+(x<<3)+c-'0';
        c=getchar();
    }
    return x*w;
}
inline void add_edge(int from,int to)
{
    edge[++num]=(E){to,head[from]};
    head[from]=num;
}
int find(int x){return f[x]==x?x:f[x]=find(f[x]);}
bool cmp1(node a,node b){return a.x==b.x?a.id>b.id:a.x<b.x;}
bool cmp2(node a,node b){return a.x==b.x?a.id>b.id:a.x>b.x;}
int main()
{
    n=read();
    for (reg int i=1;i<=n;i++) c[i]=(node){read(),i},f[i]=i,tot[i]=1;
    for (reg int i=1;i<n;i++)
    {
        int x=read(),y=read();
        add_edge(x,y); add_edge(y,x);
    }
    sort(c+1,c+n+1,cmp1);
    for (reg int j=1;j<=n;j++)
    {
        int k=c[j].id; vis[k]=1;
        for (reg int i=head[k];i;i=edge[i].nxt)
        {
            int v=edge[i].to;
            if (!vis[v]) continue; v=find(v);
            ans+=1ll*c[j].x*tot[k]*tot[v];
            f[v]=k; tot[k]+=tot[v];
        }
    }
    sort(c+1,c+n+1,cmp2); memset(vis,0,sizeof(vis));
    for (reg int i=1;i<=n;i++) f[i]=i,tot[i]=1;
    for (reg int j=1;j<=n;j++)
    {
        int k=c[j].id; vis[k]=1;
        for (reg int i=head[k];i;i=edge[i].nxt)
        {
            int v=edge[i].to;
            if (!vis[v]) continue; v=find(v);
            ans-=1ll*c[j].x*tot[k]*tot[v];
            f[v]=k; tot[k]+=tot[v];
        }
    }
    printf("%lld\n",ans);
    return 0;
}