最长上升子序列、最长下降子序列的DP算法由O(n^2)到O(nlogn)算法实现及其优化

2018-01-12 13:17:05


最长上升子序列、最长下降子序列的DP算法由O(n^2)到O(nlogn)算法实现及其优化

为了更好的介绍O(nlogn)算法,我们回顾一下一般的O(n^2)的算法。

令A[i]表示输入第i个元素,d[i]表示从A[1]到A[i]中以A[i]结尾的最长子序列长度。对于任意的0 < j <= i-1,如果A(j) < A(i),则A(i)可以接在A(j)后面形成一个以A(i)结尾的新的最长上升子序列。对于所有的 0 < j <= i-1,我们需要找出其中的最大值。

DP状态转移方程:d[i] = max{1, d[j] + 1} (j = 1, 2, 3, ..., i-1 且 A[j] < A[i]) ①

对于最长不下降子序列,怎样实现O(nlogn)算法呢?我们知道O(n^2)的算法复杂度高的原因就在于要更新d[i]的值,就必须在1~i-1中枚举找到最大的d[j]的值才能最终确定d[i]的值,于是我们可以这样思考,能否直接把1~i-1中最大的d[i]的值存储起来,从而实现直接检索呢?于是就有了如下类似贪心的算法(个人理解)。

view plain copy

/*预处理*/

const int MAXN = 40010;  
const int INF = 0x3f3f3f3f;  
int n;  
int A[MAXN], S[MAXN];  
int d[MAXN];  
void init()  
{  
    for(int i = 1; i <= n; i++) S[i] = INF; //这很重要,与upper_bound有关。  
    memset(d, 0, sizeof(d));  
}  

其中d[i]和①是一样的意义,而S数组表示的意义是:所有最长上升子序列长度为d[i]时的A[i]的最小值②,请仔细理解这一段话,即S[d[i]] = min{S[d[i]], A[i]}。 举例说明:

A[]: 1、2、3、-1、1、2、3、1

d[]: 1、2、3、1 、2、3、4、2

S[]: -1、1、2、3

可以看出,S序列是严格的递增序列,可以这样理解:d[i'] = 2的最小值一定比d[i'']值为1的最小值大,因为d[i''] > d[i'],就这么简单。那么知道了S的值有什么用呢?或许聪明的读者已经看出来了,对于最长不下降子序列,只要每次将一个A[i]的值在S数组中进行检索,返回的小于等于A[i]最后一个元素的下标的位置(或者“下一个下标的位置”③)一定就是d[i]的长度。

为什么呢?因为在这下标前面的元素一定是小于A[i]的,所以d[i]的值也就是返回的下标的值,不懂的可以用笔模拟一下,这也是前面我们为什么要这样定义②的目的所在。

这里还有一个地方要注意,就是最长上升子序列的问题和最长不下降子序列的问题,如问题:1、2、3、5、5的结果是4还是5?待会我会给出满意的解法。

另外正确的二分求上界的写法,我也会给出,写到这里,笔者不得不感叹:一个正确的二分查找也是很难写的。。④

int BSearch(int x, int y, int v) //二分求上界  
{  
    while(x <= y)  
    {  
        int mid = x+(y-x)/2;  
        if(S[mid] <= v) x = mid+1;  
        else y = mid-1;  
    }  
    return x;  
}  
void dp()  
{  
    init();  
    int ans = 0;  
    for(int i = 1; i <= n; i++)  
    {  
        int x = 1, y = i;  
        int pos = BSearch(x, y, A[i]);  
        d[i] = pos;  
        S[d[i]] = min(S[d[i]], A[i]);  
        ans = max(ans, d[i]);  
    }  
    printf("%d\n", ans);  
}  

如何求严格的最长上升子序列呢?其实我们只要在二分时,把A[m] <= v改为A[m] < v即可。 对于最长不上升子序列:模仿上面的定义,我们把S数组的定义改为所有最长上升子序列长度为d[i]时的A[i]的最大值,为什么要是最大值呢?因为S数组在这里应该遵循严格的递减序列才对,为了能够检索A[i],我们必须使得返回的下标一定就是d[i]的值,具体的实现方法:S的初始值赋为-INF,二分查找的过程需要改一下,S数组更新时使用max。

/*最长不上升子序列 POJ 1887*/

void init()  
{  
    for(int i = 1; i <= tot; i++) S[i] = -INF; //注意初始值   
    memset(d, 0, sizeof(d));  
}  
int BSearch(int x, int y, int v)  
{  
    while(x <= y)  
    {  
        int mid = x+(y-x)/2;  
        if(S[mid] >= v) x = mid+1; //注意看二分的变化   
        else y = mid-1;  
    }  
    return x;  
}  
void dp()  
{  
    init();  
    int ans = 0;  
    for(int i = 1; i <= tot; i++)  
    {  
        int x = 1, y = i;  
        int pos = BSearch(x, y, A[i]);  
        d[i] = pos;  
        S[d[i]] = max(S[d[i]], A[i]); //max  
        ans = max(ans, d[i]);  
    }  
    printf("  maximum possible interceptions: %d\n", ans);  
}  

这里的二分是检索A[I]在S数组中的下标,如果没有任何数比A[i]小,那么返回值应该是S当前数组的长度+1,不下降子序列刚好相反。 最长不下降子序列的优化:

对于最长不下降子序列,对于③我们发现,BSearch的过程相当于,对于一个整数b来说,是求小于等于b的最后一个元素的“下一个下标”R是什么?所以我们可以用到STL中的函数,upper_bound,这样我们不必再去手写二分,也就减少了一些代码量。

/*核心代码*/

void dp()  
{  
    init();  
    int ans = 0;  
    for(int i = 1; i <= n; i++)  
    {  
        int x = 1, y = i;  
        int pos = upper_bound(S, S+i, A[i]) - S; //upper_bound  
        d[i] = pos;  
        S[d[i]] = min(S[d[i]], A[i]);  
        ans = max(ans, d[i]);  
    }  
    printf("%d\n", ans);  
}  

最终,笔者还是不得不感叹:一个正确的二分查找也是很难写的。。。 对于④的解决方案:http://blog.csdn.net/wall\_f/article/details/8296194