题解 P5496 【【模板】回文自动机(PAM)】
Ireliaღ
2019-10-18 08:48:24
# 回文自动机(回文树)
## 简介
### 概述
回文自动机(PAM),也叫回文树
可以用 $O(n)$ 的时间复杂度求出一个字符串的所有回文子串
本蒟蒻是学了两遍才学明白的,这里推荐一下[B站上关于回文自动机的讲解](https://www.bilibili.com/video/av25326779?from=search&seid=16439865067541274096)
当然如果不方便看视频的话,也可以看一下我自己关于回文自动机的一些理解
### 正文
**节点含义**
类比 AC 自动机每个节点的含义
回文自动机每个节点的含义表示在它的父节点两侧各加上一个儿子字符
**奇根偶根**
由于回文串有奇数长度和偶数长度两种
所有我们的回文自动机自然会有两个根——奇根和偶根
偶根的节点编号为 $0$,所代表的回文串的长度为 $0$,`fail` 指针指向奇根
奇根的节点编号为 $1$,所代表的回文串的长度为 $-1$,`fail` 指针指向自身(其实无所谓)
**`fail` 指针**
我们再来说一下 `fail` 指针的含义
一个节点的 `fail` 指针,指向的是这个节点的最长回文后缀
所以在新加入一个字符的时候,我们要从当前节点不断的跳 `fail` 指针
直到跳到某一个节点所表示的回文串的两侧都能扩展一个待添加的字符
我们就看这个节点有没有这个儿子,如果有就直接走下去,没有就新建一个节点
新建节点的长度等于这个节点的长度加上 $2$(因为是回文串,要在两侧各扩展一个字符)
那每个节点的 `fail` 指针要怎么求那
我们可以考虑一个节点的最长回文后缀
必然是在它父节点的某个回文后缀两侧各拓展一个当前字符得到的
所以新建一个节点之后,我们可以从它父亲的 `fail` 节点开始,不断的跳 `fail` 指针
直到跳到第一个两侧能拓展这个字符的节点为止,那么该节点的儿子就是新建节点最长回文后缀
之后我们再看一下奇根和偶根的 `fail` 指针,由于奇根的子节点表示的回文串长度为 $1$,也是就该字符本身
所以奇根相当于是可以向两侧扩展任意字符的,所以我们把偶根的 `fail` 指针指向奇根
而如果跳到了奇根,一定能向两侧扩展,所以奇根的 `fail` 指针自然就无所谓了
代码实现起来非常的简单
**程序实现**
```cpp
char s[MAXN];
struct Node{
int sum, len;//统计每个回文串的出现次数, len表示当前节点回文串的长度
Node *fail, *ch[26];//每个节点的儿子, fail如上所述
Node() {}
}tr[MAXN];
Node *last;
int ncnt;//ncnt表示节点数,last表示当前节点
Node *New(int len) {//新建一个节点
tr[ncnt].len = len;
tr[ncnt].fail = &tr[0];
for (int i = 0; i < 26; i++) tr[ncnt].ch[i] = &tr[0];
return &tr[ncnt++];
}
Node *GetFail(Node *pre, int now) {//跳fail指针
while (s[now - pre->len - 1] != s[now]) pre = pre->fail;
return pre;
}
void Build() {
ncnt = 2;//tr[0] tr[1]为奇数根和偶数根 ,其他节点从2开始
tr[0].len = 0; tr[1].len = -1;//初始化,如上所述
tr[0].fail = &tr[1]; tr[1].fail = &tr[0];//初始化,如上所述
for (int i = 0; i < 26; i++) tr[0].ch[i] = tr[1].ch[i] = &tr[0];
last = &tr[0];
for (int i = 1; s[i]; i++) {
Node *cur = GetFail(last, i);//从当前节点开始,找到可扩展的节点
if (cur->ch[s[i] - 'a'] == &tr[0]) {//没有这个儿子
Node *now = New(cur->len + 2);//新建节点
now->fail = GetFail(cur->fail, i)->ch[s[i] - 'a'];//找到最长回文后缀
cur->ch[s[i] - 'a'] = now;//父子相认
}
last = cur->ch[s[i] - 'a'];//更新last
last->sum++;//顺带求出每个回文串的出现次数
}
}
```
**应用**
如果要求本质不同的回文串的个数,直接输出`cnt−2`即可(除去奇根)
如果要统计每个回文串的出现次数,还要从叶子节点向根遍历一遍
因为我们当时统计回文串时只统计了完整的回文串,但并没有记录它的子串
所以我们要按照拓扑序将每个节点的最长回文后缀的出现次数加上该节点的出现次数
这样我们就得到了一个字符串的所有回文子串的出现次数
```cpp
for (Node *p = tr + ncnt - 1; p != tr + 1; p--) p->fail->sum += p->sum;
```
另外,回文自动机还有一种常见操作就是在构造的时候顺带求出一个 `trans` 指针
`trans` 指针的含义是小于等于当前节点长度的一半最长回文后缀,求法和 `fail` 指针的求法类似
当我们新建一个节点后,如果它的长度小于等于 $2$,那么这个节点的 `trans` 指针指向它的 `fail` 节点
否则的话,我们同理从它父亲的 `trans` 指针指向的节点开始跳 `fail` 指针
直到跳到某一个节点所表示的回文串的两侧都能扩展这个字符
并且拓展后的长度小于等于当前节点长度的一半
那么新建节点的 `trans` 的指针就指向该节点的儿子
代码实现区别也不大
```cpp
void Build() {
ncnt = 2;
tr[0].len = 0; tr[1].len = -1;
tr[0].fail = &tr[1]; tr[1].fail = &tr[0];
for (int i = 0; i < 26; i++) tr[0].ch[i] = tr[1].ch[i] = &tr[0];
last = &tr[0];
for (int i = 1; s[i]; i++) {
Node *cur = GetFail(last, i);
if (cur->ch[s[i] - 'a'] == &tr[0]) {
Node *now = New(cur->len + 2);
now->fail = GetFail(cur->fail, i)->ch[s[i] - 'a'];
cur->ch[s[i] - 'a'] = now;
if (now->len <= 2) {
now->trans = now->fail;
} else {
Node *tmp = cur->trans;
while (s[i - tmp->len - 1] != s[i] || ((tmp->len + 2) << 1) > now->len) tmp = tmp->fail;
now->trans = tmp->ch[s[i] - 'a'];
}
}
last = cur->ch[s[i] - 'a'];
last->sum++;
}
}
```
有了 trans 指针之后,我们就可以很轻松的切了这道[双倍回文](https://www.luogu.org/problemnew/show/P4287)
**后续**
如果 一道题空间限制较小,`PAM`并不能完全代替`Manacher`
## 本题写法
我们只需要注意每次添加新的字符时要处理一下强制在线就OK了,需要求出的就是本质不同的回文串个数。
```cpp
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
const int MAXN = 2e6 + 5;
struct Node{
Node *ch[26], *fail;
int len, sum;
}tr[MAXN];
int n;
char s[MAXN];
int ncnt = 2;
Node *last;
Node *New(int len) {
tr[ncnt].len = len;
tr[ncnt].fail = &tr[0];
for (int i = 0; i < 26; i++) tr[ncnt].ch[i] = &tr[0];
return &tr[ncnt++];
}
Node *GetFail(Node *x, int pos) {
while (s[pos - x->len - 1] != s[pos]) x = x->fail;
return x;
}
void Init() {
tr[0].len = 0; tr[1].len = -1;
tr[0].fail = &tr[1]; tr[1].fail = &tr[0];
for (int i = 0; i < 26; i++) tr[0].ch[i] = tr[1].ch[i] = &tr[0];
last = &tr[0];
}
void Insert(int pos) {
Node *cur = GetFail(last, pos);
if (cur->ch[s[pos] - 'a'] == &tr[0]) {
Node *now = New(cur->len + 2);
now->fail = GetFail(cur->fail, pos)->ch[s[pos] - 'a'];
now->sum = now->fail->sum + 1;
cur->ch[s[pos] - 'a'] = now;
}
last = cur->ch[s[pos] - 'a'];
}
int main() {
cin >> s + 1;
n = strlen(s + 1);
int k = 0;
Init();
for (int i = 1; i <= n; i++) {
s[i] = (s[i] - 97 + k) % 26 + 97;
Insert(i);
cout << last->sum << " ";
k = last->sum;
}
return 0;
}
```
## 感谢
我其实讲解部分转载的是[这篇文章](https://www.cnblogs.com/Vscoder/p/10498625.html),由于[作者](https://www.luogu.org/space/show?uid=121921)是我同学,并且已经退役了,所以我就经过他的同意直接搬运过来了。
![](https://cdn.luogu.com.cn/upload/image_hosting/0i8senwj.png)