P3181 [HAOI2016]找相同字符

2018-03-28 10:57:12


一道后缀数组的题目。

初学的蒟蒻也只能抄抄题解这样子。

最暴力的想法:在两个串中枚举极长子串

可以只考虑两原串某后缀的所有前缀

从而不重不漏找到所有子串

考虑到后缀的前缀,自然想到后缀数组

把两个串接在一起,中间隔一个其他字符(除小写字母外)

求出任意两后缀的LCP的值

可以用het数组和ST表实现,但依然不足以AC

进一步考虑:

由于het表示的是排名相邻的两个后缀LCP的长度,

所以任意两个后缀的LCP长度为按字典序排序后它们中间最小的het

也就是说排序后,一个后缀越往后数LCP的长度越小

这样,我们就可以用单调栈维护这个最小值

单调栈中有两个值:一个是het值,一个是位置i

i在这里充当一个系数(因为弹出的元素实际上还会放回去)

分A串的子串在前和B串的子串在前两种情况进行讨论

两种情况答案相加即可

#include<cstdio>
#include<cstring>
#include<algorithm>
#define reg register
using namespace std;
typedef long long ll;
const int N=4e5+5;
char c[N];
int n,l1,l2,top,sum[N];
pair<int,ll>stack[N];
struct HOU
{
    int n,m,a[N],top[N],rank[N],sa[N],tax[N],het[N];
    inline void qsort()
    {
        memset(tax,0,sizeof(tax));
        for (reg int i=1;i<=n;i++) ++tax[rank[i]];
        for (reg int i=1;i<=m;i++) tax[i]+=tax[i-1];
        for (reg int i=n;i>=1;i--)
          sa[tax[rank[top[i]]]--]=top[i];
    }
    inline void getSA()
    {
        for (reg int i=1;i<=n;i++) rank[i]=a[i],top[i]=i;
        m=127; qsort();
        for (reg int w=1,p=0;p<n;m=p,w<<=1)
        {
            p=0;
            for (reg int i=1;i<=w;i++) top[++p]=n-w+i;
            for (reg int i=1;i<=n;i++)
              if (sa[i]>w) top[++p]=sa[i]-w;
            qsort(); swap(rank,top);
            rank[sa[1]]=p=1;
            for (reg int i=2;i<=n;i++)
              if (top[sa[i-1]]==top[sa[i]]&&top[sa[i-1]+w]==top[sa[i]+w])
                rank[sa[i]]=p;
              else rank[sa[i]]=++p;
        }
        int k=0;
        for (reg int i=1;i<=n;i++)
        {
            k=(k?k-1:0);
            while (c[i+k]==c[sa[rank[i]-1]+k]) ++k;
            het[rank[i]]=k;
        }
    }
}R;
inline ll getans()//两个后缀的LCP(最长公共前缀长度)为按照字典序排序后它们之间最小的het
{
    ll ans=0;
    stack[0]=make_pair(1,0);
    for (reg int i=1;i<=R.n;i++)
      sum[i]=sum[i-1]+(R.sa[i]<=l1);
    for (reg int i=1;i<=R.n;i++)
    {
        while (top&&R.het[stack[top].first]>R.het[i]) --top;
        stack[++top]=make_pair(i,1ll*(sum[i-1]-sum[stack[top-1].first-1])*R.het[i]+stack[top-1].second);
        if (R.sa[i]>l1+1) ans+=stack[top].second;
    }
    top=0;
    for (reg int i=1;i<=R.n;i++)
      sum[i]=sum[i-1]+(R.sa[i]>l1+1);
    for (reg int i=1;i<=R.n;i++)
    {
        while (top&&R.het[stack[top].first]>R.het[i]) --top;
        stack[++top]=make_pair(i,1ll*(sum[i-1]-sum[stack[top-1].first-1])*R.het[i]+stack[top-1].second);
        if (R.sa[i]<=l1) ans+=stack[top].second;
    }
    return ans;
}
int main()
{
    scanf("%s",c+1); l1=strlen(c+1);//第一个串 
    scanf("%s",c+l1+2); c[l1+1]='z'+1;//把第二个串接在第一个串后面,中间隔一个其他字符 
    R.n=strlen(c+1);
    for (reg int i=1;i<=R.n;i++) R.a[i]=c[i]-'a'+1;
    R.getSA(); printf("%lld\n",getans());
    return 0;
}