题解 SP2916 【GSS5 - Can you answer these queries V】

energy2002

2018-09-29 20:41:13

Solution

大概可以算是线段树维护区间最大子段和的模板题了吧... ## 题意 对于长度为$n$的序列,回答$m$个询问,每个询问查询左端点在$[l_1,r_1]$之中,右端点在$[l_2,r_2]$之中的所有区间和的最大值,即: $$\max_{l_1\leq l\leq l_2,r_1\leq r\leq r_2} \sum_{i=l}^{j} a[i]$$ ( 其中保证$l_1\leq l_2;r_1\leq r_2;n,m,|a[i]|\leq 1E4$ ). ## solution: ## 0 如何建立线段树 ### 0.0 节点信息 建立线段树时,对于每个节点,额外存储4个信息$pre,mid,suf,sum$; 详细信息如图: ![](https://cdn.luogu.com.cn/upload/pic/34790.png) 可以看出,父节点的$pre$可由其左儿子的$pre$或其左儿子的$sum$加右儿子的$pre$更新,在此情况中,父节点的$pre$由前者更新; ### 0.1 更新 类比父节点$pre$的更新方式,可以写出其$mid$与$suf$的更新方式: (为了方便,我把这4个数据封装到一个struct中了...) ```cpp struct trans{ int pre,mid,suf,sum; }; trans merge(trans s1,trans s2){ trans ans; ans.pre=max(s1.pre,s2.pre+s1.sum); ans.mid=max(max(s1.mid,s2.mid),s1.suf+s2.pre); //父节点的mid要由3个数据更新 ans.suf=max(s1.suf+s2.sum,s2.suf); ans.sum=s1.sum+s2.sum; //别忘了更新sum return ans; } ``` 而对于叶节点,在建树时顺便更改其数据即可,这样就可以很容易地写出建树的代码了: ```cpp void build(int rt,int l,int r){ p[rt].l=l;p[rt].r=r; if(l==r){ p[rt].dat.pre=p[rt].dat.mid=p[rt].dat.suf=p[rt].dat.sum=dat[l]; //我这里规定每个子序列不可以为空,如果这里规定序列可以为空,那相应的,查询时也应更改。 return; } int mid=(l+r)/2; build(rt*2,l,mid); build(rt*2+1,mid+1,r); update(rt); return; } ``` Tips:如果还没有写过这类问题,建议先做[GSS1](https://www.luogu.org/problemnew/show/SP1043) ## 1 区间最大字段和查询 与普通线段树差不多,此线段树的查询结果也由多个节点的值合并而成: ![](https://cdn.luogu.com.cn/upload/pic/34807.png) 如图,比如查询的区间是$[3,8]$,那么,查询的答案将由存储范围为$[3,4]$和$[5,8]$的两个区间合并而成; 我们不难发现,在建立线段树和查询时信息的合并方式是相同的,那么,我们也就可以很容易地写出查询的代码了: ```cpp trans query_sub_max(int rt,int l,int r){ if(l>r){ //这个是为了防止之后的查询中出现的BUG用的 return (trans){0,0,0,0}; } if(p[rt].l==l&&p[rt].r==r){ return p[rt].dat; } int mid=p[rt*2].r; if(r<=mid)return query_sub_max(rt*2,l,r); else if(l>mid)return query_sub_max(rt*2+1,l,r); //如果要查寻的区间没有跨越其中一个儿子的范围,直接返回即可 return merge(query_sub_max(rt*2,l,mid),query_sub_max(rt*2+1,mid+1,r)); } ``` 可以看出,查询一次的复杂度为$O(log(n))$ ## 2 对左右端点有限制的区间的查询 本题对左右端点的范围有限制,不能一次查询出结果,于是可以考虑分情况多次查询: ### 2.1 两个区间没有交集的情况 ![](https://cdn.luogu.com.cn/upload/pic/34817.png) 如图,可以看出,蓝框中的数是无论如何要选上的,为了最大化结果,只能最大化区间$[l_1,r_1]$与$[l_2,r_2]$的选中部分的和,所以,不难写出这种情况的代码: ```cpp if(r1<l2){ int tmp=query_sub_max(1,l1,r1).suf; //区间[l1,r1]的最大后缀和 tmp+=query_sub_max(1,r1+1,l2-1).sum; //蓝框部分的和 tmp+=query_sub_max(1,l2,r2).pre; //区间[l2,r2]的最大前缀和 return tmp; } ``` ### 2.2 两个区间有交集的情况 而这种情况就比较复杂了,可分为3种情况,如图: ![](https://cdn.luogu.com.cn/upload/pic/34820.png) 我们可以发现,其实情况1与情况2是可以用同一种方式求出的,即区间$[l_1,l_2]$的$suf$+区间$[l_2,r_2]$的$pre$(或区间$[l_1,r_1]$的$suf$+区间$[r_1,r_2]$的$pre$); 而情况3就很好写了,只需求出区间$[l_2,r_1]$的$mid$即可; **Tips: 注意特判区间$[l_1,r_1],[r_1,r_2]$是否存在,即$l_1$与$l_2$,$r_1$与$r_2$是否相等。** 代码如下: ``` int ans=query_sub_max(1,l2,r1).mid; if(l1<l2)ans=max(ans,query_sub_max(1,l1,l2).suf+query_sub_max(1,l2,r2).pre-dat[l2]); if(r2>r1)ans=max(ans,query_sub_max(1,l1,r1).suf+query_sub_max(1,r1,r2).pre-dat[r1]); return ans; } ``` ## 最后,附上AC代码: ```cpp #include<bits/stdc++.h> #define debug(x) cerr<<#x<<" = "<<x<<endl; #define db(x) debug(x) using namespace std; int dat[10005],n; struct trans{ int pre,mid,suf,sum; int fin(){ return max(max(pre,suf),mid); } trans DB(){ db(pre);db(mid);db(suf);db(sum); return *this; } }; trans merge(trans s1,trans s2){ trans ans; ans.pre=max(s1.pre,s2.pre+s1.sum); ans.mid=max(max(s1.mid,s2.mid),s1.suf+s2.pre); ans.suf=max(s1.suf+s2.sum,s2.suf); ans.sum=s1.sum+s2.sum; return ans; } struct node{ int l,r; trans dat; }p[40005]; void update(int rt){ p[rt].dat=merge(p[rt*2].dat,p[rt*2+1].dat); return; } void build(int rt,int l,int r){ p[rt].l=l;p[rt].r=r; if(l==r){ p[rt].dat.pre=p[rt].dat.mid=p[rt].dat.suf=p[rt].dat.sum=dat[l]; return; } int mid=(l+r)/2; build(rt*2,l,mid); build(rt*2+1,mid+1,r); update(rt); return; } trans query_sub_max(int rt,int l,int r){ if(l>r){ return (trans){0,0,0,0}; } if(p[rt].l==l&&p[rt].r==r){ return p[rt].dat; } int mid=p[rt*2].r; if(r<=mid)return query_sub_max(rt*2,l,r); else if(l>mid)return query_sub_max(rt*2+1,l,r); return merge(query_sub_max(rt*2,l,mid),query_sub_max(rt*2+1,mid+1,r)); } int query(int l1,int r1,int l2,int r2){ if(r1<l2){ int tmp=query_sub_max(1,l1,r1).suf; tmp+=query_sub_max(1,r1+1,l2-1).sum; tmp+=query_sub_max(1,l2,r2).pre; return tmp; } int ans=query_sub_max(1,l2,r1).mid; if(l1<l2)ans=max(ans,query_sub_max(1,l1,l2).suf+query_sub_max(1,l2,r2).pre-dat[l2]); if(r2>r1)ans=max(ans,query_sub_max(1,l1,r1).suf+query_sub_max(1,r1,r2).pre-dat[r1]); return ans; } int main(){ int t; scanf("%d",&t); while(t--){ scanf("%d",&n); for(int i=1;i<=n;i++)scanf("%d",&dat[i]); build(1,1,n); int m; scanf("%d",&m); while(m--){ int x1,x2,y1,y2; scanf("%d%d%d%d",&x1,&y1,&x2,&y2); printf("%d\n",query(x1,y1,x2,y2)); } } return 0; } ```