题解 P1736 【创意吃鱼法】
刘备
2017-11-27 23:42:57
本蒟蒻终于是A了这道题!!!
做完后翻开题解,发现这道题的做法实在是多,有dp,暴搜,还有二分答案再配上二维前缀和。
这里我们**详细**介绍一种**二维前缀和+dp+二分查找**的方法。
时间复杂度:O(n^2\*logn),加了读入优化差不多650ms。(再次感叹这题读优的强大,整整快了700ms)
**好了,进入正题**
刚看到这道题,第一件要做的事情就是将题目简化:
求一个01矩阵中,满足**左/右对角线上全是1,其余位置都是0的正方形子矩阵中的最大的那个的对角线长度**
当看到子矩阵,我就想到了二维前缀和,二维前缀和的原理请自行百度或看以下代码,很好理解的(与一维前缀和相似),这里我们只讲一下其大概功能: **通过O(n^2)初始化,O(1)查询所求子矩阵中的元素和**
好,既然已经有了这么一个利器,我们考虑下如何dp:
首先这题我们要求最长的符合左对角线或右对角线,其实我们可以分开求,以下以左对角线举例,右对角线只要复制粘贴并稍改以下就OK了。
**状态设计:**
我们设dp[i][j]表示以(i,j)结尾的符合条件的最长的左对角线长度,那么最终答案就是对所有dp[i][j]取最大值,然后同理求右对角线,然后输出max(左对角线最大值,右对角线最大值)即可。
**状态转移:**
继续以左对角线为例。我们找到一个在地图上为1的点(i,j)(如果在地图上为0,那么根据定义其dp值必定为0),发现它一定与dp[i-1][j-1]
有关,但是并不是dp[i][j]=dp[i-1][j-1]+1那么简单,因为有可能加上这个点后就不符合条件了。这时有的同学就会说:那简单,若不符合条件,就把dp[i][j]改为1(即它本身组成的矩阵),**错!**
这两种想法其实都太极端,一个是全部继承,一个是全不继承。但其实它还有可能**部分继承dp[i-1][j-1]的值**,这里大家一定要好好理解一下。
最后一个问题,**如何算出究竟继承了多少**?最初的想法是暴力,从全部继承到全不继承枚举,这样显然太慢。我们发现其实它具有单调性:**如果此时的矩阵满足条件,那么更小的一定满足条件**。根据这个,我们可以通过二分查找算出满足条件的最大对角线长度并继承,配上二维前缀和,一次的操作不超过O(logn)。
好了,理论部分结束,接下来上代码,如果大家还有什么没有理解的可以从代码中寻找答案。
```cpp
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
using namespace std;
#define N 2510
int dp[N][N],sum[N][N]; //sum表示从(1,1)到(i,j)的元素和
bool map[N][N];
int maxx=0;
int getsum(int i,int j,int x,int y) {
return sum[x][y]-sum[x][j-1]-sum[i-1][y]+sum[i-1][j-1]; //二维前缀和的精华操作,表示对角线从(i,j)到(x,y)的子矩阵的元素和
}
int read() {
int ans=0,f=1;
char ch=getchar();
while(!isdigit(ch)) {
if(ch=='-') f=-1;
ch=getchar();
}
while(isdigit(ch)) {
ans=ans*10+ch-'0';
ch=getchar();
}
return f*ans;
}
int main() {
int i,j,n=read(),m=read();
for(i=1;i<=n;i++) {
int num=0;
for(j=1;j<=m;j++) {
map[i][j]=read();
num+=map[i][j];
sum[i][j]=sum[i-1][j]+num; //初始化二维前缀和
}
}
//左对角线操作
for(i=1;i<=n;i++)
for(j=1;j<=m;j++)
if(map[i][j]) {
dp[i][j]=1; //初始化为1,在这里没什么意义,但在下面求右对角线时是必要的
if(dp[i-1][j-1]) {
int num=dp[i-1][j-1];
int l=0,r=num,ans=0;
//以下即为二分查找,功能就是找出符合条件的最长对角线长度
while(l<=r) {
int mid=l+r>>1;
if(getsum(i-mid,j-mid,i,j)==mid+1) ans=mid,l=mid+1;
else r=mid-1;
}
dp[i][j]=max(dp[i][j],ans+1);
}
maxx=max(maxx,dp[i][j]);
}
//右对角线长度
for(i=1;i<=n;i++)
for(j=1;j<=m;j++)
if(map[i][j]) {
dp[i][j]=1;
if(dp[i-1][j+1]) {
int num=dp[i-1][j+1];
int l=0,r=num,ans=0;
while(l<=r) {
int mid=l+r>>1;
if(getsum(i-mid,j,i,j+mid)==mid+1) ans=mid,l=mid+1; //注意getsum()内填入的顺序,我一开始就被坑了
else r=mid-1;
}
dp[i][j]=max(dp[i][j],ans+1);
}
maxx=max(maxx,dp[i][j]);
}
printf("%d\n",maxx);
return 0;
}
```