bzoj4709 [JSOI2011]柠檬

2018-10-08 21:38:45


题意:给出一个长度为 $n$ 的序列,需要把它分成任意多段

   对于每一段,需要指定一个数 $x$ ,如果这一段中有 $k$ 个 $x$ ,那么收益为 $xk^2$

   最大化收益总和


决策单调性优化 $DP$

贴上原文链接:orz

首先可以发现选出的每一个区间的开头结尾一定是相同的数

如果不是相同的数,那么这个数就可以放到其他区间中贡献答案,不影响这个区间

并且我们只关注如何划分这个序列,选择区间的顺序不会产生影响


设 $f[i]$ 表示到第 $i$ 个数的最大收益

有转移方程: $f[i]=max\left\{f[j-1]+a[i]\times(s[i]-s[j]+1)^2\right\},(a[j]==a[i])$

其中 $s[i]$ 表示到第 $i$ 个位置 $a[i]$ 这个数的出现次数

这样转移是 $n^2$ 的,考虑优化


发现 $s[i]$ 是单调上升的, $(s[i]-s[j]+1)^2$ 也是单调上升的,并且增长越来越快

所以如果存在 $k<j$ 且 $k$ 比 $j$ 更优,则 $k$ 一直比 $j$ 更优

于是对每一个 $a[i]$ 用单调栈维护,当栈顶的第二个元素比第一个元素更优时,弹出栈顶元素


这样产生了一个问题:如果第二个元素不优于第一个,但第三个元素更优怎么办?

对于任意的 $j1<j2<i1<i2$ ,如果 $j1$ 超过 $i1$ 的时间小于 $j2$ 超过 $i1$ 的时间,那么 $j1$ 超过 $i2$ 也早于 $j2$ 超过 $i2$

可以二分求出任意一个 $j$ 超过 $k$ 的时间

在将 $i$ 入栈前,先判断第二个元素超过 $i$ 的时间是否小于第一个元素超过 $i$ 的时间,如果是就弹出栈顶,否则将 $i$ 入栈

这样可以保证每个元素超过上一个元素的时间也是单调的

#include<cstdio>
#include<cstring>
#include<cctype>
#include<vector>
#include<algorithm>
#define reg register
using namespace std;
typedef long long ll;
const int N=1e5+5,M=1e4+5;
int n,a[N],cnt[M],s[N];
ll f[N];
vector<int>p[M];
inline int read()
{
    int x=0,w=1;
    char c=getchar();
    while (!isdigit(c)&&c!='-') c=getchar();
    if (c=='-') c=getchar(),w=-1;
    while (isdigit(c))
    {
        x=(x<<1)+(x<<3)+c-'0';
        c=getchar();
    }
    return x*w;
}
inline ll calc(int x,int y){return f[x-1]+1ll*a[x]*y*y;}
inline int getpos(int x,int y)
{
    int l=1,r=n,ans=n+1;
    while (l<=r)
    {
        int mid=(l+r)>>1;
        if (calc(x,mid-s[x]+1)>=calc(y,mid-s[y]+1)) ans=mid,r=mid-1;
        else l=mid+1;
    }
    return ans;
}
int main()
{
    n=read();
    for (reg int i=1;i<=n;i++)
    {
        s[i]=++cnt[a[i]=read()];
        while (p[a[i]].size()>1&&getpos(p[a[i]][p[a[i]].size()-2],p[a[i]][p[a[i]].size()-1])<=getpos(p[a[i]][p[a[i]].size()-1],i)) p[a[i]].pop_back();
        p[a[i]].push_back(i);
        while (p[a[i]].size()>1&&getpos(p[a[i]][p[a[i]].size()-2],p[a[i]][p[a[i]].size()-1])<=s[i]) p[a[i]].pop_back();
        f[i]=calc(p[a[i]][p[a[i]].size()-1],s[i]-s[p[a[i]][p[a[i]].size()-1]]+1);
    }
    printf("%lld\n",f[n]);
    return 0;
}