题解 P4119 【[Ynoi2018]未来日记】
fr200110217102
2019-02-08 00:22:45
调了五个小时终于调出来了……不愧是Ynoi毒瘤题
~~一个不保证不会被毒瘤数据卡掉的算法。~~
---
# 分块+分块+并查集
~~看到这么诡异的修改操作显然想到分块~~
## 区间第$k$大
先考虑一下分块做区间第$k$大
可以二分/树状数组,然而复杂度会变成$O(n\sqrt{n}logn)$,基本上是$TLE$无疑了。。。
换个角度,可以维护每个数在区间内出现了多少次。这个可以用一个二维的$sum$数组维护:$sum[i][j]=$前$i$块中$j$出现的次数。这样区间内一个数$x$出现的次数就是中间整块$sum$的两个前缀和相减,再加上两边散块的出现次数。散块的出现个数可以用临时数组记下来。
然而直接枚举第$k$大并按顺序把出现次数加上是$O(n^2)$的。。。
于是再对值域分块,$cnt[i][x]=$第$i$块中$x$出现的次数(用于后面修改操作),$sumc[i][x]=$前$i$块中$x$出现的次数,$sums[i][x]=$前$i$块中$(x-1)S+1$到$xS$出现的次数之和。这里$S$取$\sqrt{100000}$(约为$317$)即可。~~(分块套分块)~~
每次求区间第$k$大的时候先枚举在哪一块,再在块内枚举。枚举到某个数发现加起来出现次数大于$k$了答案就是它了。
~~(突发奇想)其实这个分块套分块好像再优化一些就变成了树状数组上二分?~~反正在这道题上没用。
假设$n$与值域同阶。这样询问的复杂度就是$O(n\sqrt{n})$了。
## 区间修改
值域$1e5$还没提醒你什么吗……
每一块用一个并查集把所有值相同的点缩在一起。每块再维护一个$rt[i][x]$,存第$i$块中某个值为$x$的数的位置。
### 整块修改
$x$变$y$时,如果有$x$,就$fa[rt[i][x]]=rt[i][y]$;否则$rt[i][y]=rt[i][x]$。然后更新该块及后面的块的$sumc$和$sums$。注意要把各个块的$cnt[i][x]$的和存下来实时更新,要不就变成$O(n^2)$了!
### 散块修改
由于并查集的树形结构,修改时不能遇到一个值为$x$的就直接链到$rt[i][y]$上,否则可能会导致下面本来不应修改的值也被链到$rt[i][y]$的子树中。
一个解决方法是强行规定$rt[i][x]$是块内最左边的值为$x$的数的位置,不过这样维护起来需要较多的分类讨论。
都已经在一个散块里了,为什么不暴力呢?直接重构啊!
当然把整个散块的并查集拆掉然后重构还是会$TLE$飞的。实测$20-50$分不等。
只重构$x$和$y$的两棵子树就即可。
复杂度还是$O(n\sqrt{n})$。
另外因为空间限制,块长取$\sqrt{n}$会直接$MLE$掉。所以可以取一个稍大一些的值(好像块长取$600$左右的时候最优)。
~~所以这个做法能会被值域只有$2$个数的数据卡掉?~~
代码
```cpp
//#define ice
#pragma GCC optimize(3)
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#include<bits/stdc++.h>
#define rint register int
using namespace std;
inline int read(){
int res=0;char c=getchar(),f=1;
while(c<48||c>57){if(c=='-')f=0;c=getchar();}
while(c>=48&&c<=57)res=(res<<3)+(res<<1)+(c&15),c=getchar();
return f?res:-res;
}
const int N=1e5+5,M=170,S=320;
int n,m,a[N],sz,vsz,mxv;
int k,L[M],R[M],bel[N];
int l,r,x,y,op,lb,rb;
int fa[N],rt[M][N];
inline int find(int x){return fa[x]==x?x:fa[x]=find(fa[x]);}
int cnt[M][N],sumc[M][N],sums[M][S];
void build(int p){
for(rint i=L[p];i<=R[p];++i){
if(!rt[p][a[i]])rt[p][a[i]]=i;
else fa[i]=rt[p][a[i]];
++cnt[p][a[i]];
}
}
int sta[N];
void update(int p,int l,int r,int x,int y){
rint tmp=0,l0=L[p],r0=R[p],top=0;
rt[p][x]=rt[p][y]=0;
for(rint i=l0;i<=r0;++i){
a[i]=a[find(i)];
if(a[i]==x||a[i]==y)sta[++top]=i;
}
for(rint i=l;i<=r;++i)if(a[i]==x)a[i]=y,++tmp;
for(rint i=1;i<=top;++i)fa[sta[i]]=sta[i];
for(rint i=1,t,w;i<=top;++i){
t=sta[i],w=a[t];
if(!rt[p][w])rt[p][w]=t;
else fa[t]=rt[p][w];
}
cnt[p][x]-=tmp,cnt[p][y]+=tmp;
for(rint i=p;i<=k;++i){
sumc[i][x]-=tmp,sumc[i][y]+=tmp;
if(bel[x]!=bel[y])
sums[i][bel[x]]-=tmp,sums[i][bel[y]]+=tmp;
}
}
int c[N],s[S];
int main(){
#ifdef ice
freopen("1.in","r",stdin);
freopen("1.out","w",stdout);
#endif
n=read(),m=read();
sz=600,vsz=317,mxv=1e5;
k=(n-1)/sz+1;
for(rint i=1;i<=n;++i)a[i]=read(),fa[i]=i;
for(rint i=1;i<=mxv;++i)bel[i]=(i-1)/vsz+1;
for(rint i=1;i<=k;++i){
L[i]=(i-1)*sz+1,R[i]=min(i*sz,n);
build(i);
for(rint j=1;j<=vsz;++j)
sums[i][j]=sums[i-1][j];
for(rint j=1;j<=mxv;++j)
sumc[i][j]=sumc[i-1][j]+cnt[i][j];
for(rint j=L[i];j<=R[i];++j)
++sums[i][bel[a[j]]];
}
while(m--){
op=read();
if(op==1){
l=read(),r=read(),x=read(),y=read();
if(x==y)continue;
lb=(l-1)/sz+1,rb=(r-1)/sz+1;
if(lb==rb)update(lb,l,r,x,y);
else{
update(lb,l,R[lb],x,y);
update(rb,L[rb],r,x,y);
rint tmp,tmps=0;
for(rint i=lb+1;i<rb;++i){
if(rt[i][x]){
if(!rt[i][y])rt[i][y]=rt[i][x],a[rt[i][x]]=y;
else fa[rt[i][x]]=rt[i][y];
rt[i][x]=0,tmp=cnt[i][x],tmps+=tmp,
cnt[i][y]+=tmp,cnt[i][x]=0;
}
sumc[i][x]-=tmps,sumc[i][y]+=tmps;
if(bel[x]!=bel[y])
sums[i][bel[x]]-=tmps,sums[i][bel[y]]+=tmps;
}
for(rint i=rb;i<=k;++i){
sumc[i][x]-=tmps,sumc[i][y]+=tmps;
if(bel[x]!=bel[y])
sums[i][bel[x]]-=tmps,sums[i][bel[y]]+=tmps;
}
}
}else{
l=read(),r=read(),x=read();
lb=(l-1)/sz+1,rb=(r-1)/sz+1;
if(lb==rb){
for(rint i=l;i<=r;++i)
a[i]=a[find(i)],++c[a[i]],++s[bel[a[i]]];
rint vl,vr,tmp=0;
for(rint i=1;i<=vsz;++i){
tmp+=s[i];
if(tmp>=x){tmp-=s[i],vl=(i-1)*vsz+1,vr=i*vsz;break;}
}
for(rint i=vl;i<=vr;++i){
tmp+=c[i];
if(tmp>=x){printf("%d\n",i);break;}
}
for(rint i=l;i<=r;++i)
--c[a[i]],--s[bel[a[i]]];
}
else{
for(rint i=l;i<=R[lb];++i){
a[i]=a[find(i)];
++c[a[i]],++s[bel[a[i]]];
}
for(rint i=L[rb];i<=r;++i){
a[i]=a[find(i)];
++c[a[i]],++s[bel[a[i]]];
}
rint vl,vr,tmp=0;
for(rint i=1;i<=vsz;++i){
tmp+=s[i]+sums[rb-1][i]-sums[lb][i];
if(tmp>=x){tmp-=s[i]+sums[rb-1][i]-sums[lb][i],vl=(i-1)*vsz+1,vr=i*vsz;break;}
}
for(rint i=vl;i<=vr;++i){
tmp+=c[i]+sumc[rb-1][i]-sumc[lb][i];
if(tmp>=x){printf("%d\n",i);break;}
}
for(rint i=l;i<=R[lb];++i)
--c[a[i]],--s[bel[a[i]]];
for(rint i=L[rb];i<=r;++i)
--c[a[i]],--s[bel[a[i]]];
}
}
}return 0;
}
```