ouuan 的博客

ouuan 的博客

题解 CF235C 【Cyclical Quest】

posted on 2019-02-27 13:49:49 | under 题解 |

欢迎到我的博客阅读。

题目链接

洛谷

CF problemset

CF contest

题意简述

给你一个字符串 $s$ 和 $n$ 个字符串 $x_{1..n}$ ,对每个 $x_i$ ,求有多少个 $s$ 的子串可以由 $x_i$ 旋转得到。

旋转一个字符串就是把它的一个前缀移到后面,如 abcd 可以旋转得到的字符串有 abcdbcdacdabdabc

简要做法

对 $s$ 建 SAM,把 $x_i$ 旋转得到的每个字符串用 SAM 读入,就可以求答案了。(SAM 求子串出现次数是经典问题,可以参考我的博客

分开读入每个 $x_i$ 旋转得到的字符串显然会超时,然而,SAM 读入字符串是支持删除首字符的:记录当前读入的长度 $l$ 以及所处状态 $u$ ,删除字符就把 $l$ 减一,若减一后 $l=len(parent(u))$ ,则转移到 $parent(u)$ (把 $u$ 赋值为 $parent(t)$ )。需要注意的是,如果读入一个字符的时候当前状态没有这个字符的出边,就需要在 $parent$ 树上向上跳,直到有这个字符的出边,同时更新 $l$ 。这样的话,删除字符前就要先判断 $l$ 与需要保留的字符串的长度的关系。具体细节可以参考代码及注释。

所以,先读入 $x_i$ 统计答案,再删去首字符读入 $x_i[0]$ 统计答案,删去首字符读入 $x_i[1]$ 统计答案……就只用读入 $O(len(x_i))$ 个字符。

还有一个问题,就是去重。 $s$ 的一个子串可能可以和 $x_i$ 不同的旋转匹配。解决这个问题有两个方法,第一个是求出 $x_i$ 的周期(可以用 kmp 求),第二个方法是在 SAM 上打标记。我用的是打标记的方法,具体细节还是可以参考代码及注释。

参考代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdio>

using namespace std;

const int N=1000010;

struct Node
{
    int len,par,ch[26],vis,cnt;
} sam[N<<1];

void insert(int x);
void read(int x);
void del();
void calc();
void add(int u,int v);
void dfs(int u);

char s[N];
int head[N<<1],nxt[N<<1],to[N<<1],cnt;
int p,tot,n,m,l,u,tim,ans;

int main()
{
    int i;

    scanf("%s%d",s,&n);

    sam[0].par=-1;
    for (i=0;s[i];++i) insert(s[i]-'a');
    for (i=1;i<=tot;++i) add(sam[i].par,i);
    dfs(0);

    for (tim=1;tim<=n;++tim)
    {
        scanf("%s",s);
        m=strlen(s);
        ans=u=l=0;
        for (i=0;i<m;++i) read(s[i]-'a');
        calc();
        for (i=0;i<m-1;++i)
        {
            read(s[i]-'a');
            del();
            calc();
        }
        printf("%d\n",ans);
    }

    return 0;
}

void read(int x) //读入一个字符
{
    while (u&&!sam[u].ch[x]) l=sam[u=sam[u].par].len; //向上跳直至有这个字符的出边
    if (sam[u].ch[x])
    {
        ++l;
        u=sam[u].ch[x];
    }
}

void del() //删除首字符
{
    if (l>m&&--l==sam[sam[u].par].len) u=sam[u].par; //m表示当前xi的长度,只有l>m的时候才删除
}

void calc() //计算当前的答案
{
    if (l==m&&sam[u].vis<tim) //只有当前读入的串长度恰好为m且当前状态没有打上标记时才统计答案
    {
        ans+=sam[u].cnt;
        sam[u].vis=tim; //打标记
    }
}

void insert(int x) //向SAM中插入字符,有人把这个函数叫做extend
{
    int np=++tot;
    sam[np].len=sam[p].len+1;
    sam[np].cnt=1;
    while (~p&&!sam[p].ch[x])
    {
        sam[p].ch[x]=np;
        p=sam[p].par;
    }
    if (p==-1) sam[np].par=0;
    else
    {
        int q=sam[p].ch[x];
        if (sam[q].len==sam[p].len+1) sam[np].par=q;
        else
        {
            int nq=++tot;
            sam[nq].len=sam[p].len+1;
            memcpy(sam[nq].ch,sam[q].ch,sizeof(sam[q].ch));
            sam[nq].par=sam[q].par;
            sam[q].par=sam[np].par=nq;
            while (~p&&sam[p].ch[x]==q)
            {
                sam[p].ch[x]=nq;
                p=sam[p].par;
            }
        }
    }
    p=np;
}

void add(int u,int v) //加边,用于遍历parent树
{
    nxt[++cnt]=head[u];
    head[u]=cnt;
    to[cnt]=v;
}

void dfs(int u) //遍历parent树,计算每个状态的出现次数
{
    int i,v;
    for (i=head[u];i;i=nxt[i])
    {
        v=to[i];
        dfs(v);
        sam[u].cnt+=sam[v].cnt;
    }
}