P3294 【[SCOI2016]背单词】
communist
2018-10-20 17:36:52
### 阅读理解题
题意见楼上题解
------------
### $Trie$
后缀问题不好处理,我们把它转化为前缀问题,用字典树解决问题
### 贪心
容易想到,一个串的后缀要先于它插入
对于一个串和其若干后缀串,容易想到,我们要先插入后缀串
然后递归进入$size$最小的子串
```
bool cmp(const int &x,const int &y)
{
return size[x]<size[y];
}
void makes(int x)
{
size[x]=1;
for(int i=0;i<t[x].size();i++)
{
makes(t[x][i]);
size[x]+=size[t[x][i]];
}
sort(t[x].begin(),t[x].end(),cmp);
}
void dfs(int x)
{
id[x]=tot++;
for(int i=0;i<t[x].size();i++)
{
ans+=tot-id[x];
dfs(t[x][i]);
}
}
```
### 注意
求$size$要重构树,只保留关键点
### 和楼上不一样的地方
#### 因为我太蒻了,并不会指针,所以提供一个并查集重构树的方法
在建$Trie$时给所有串的结尾和$Trie$树的根节点标号,表示新树中点的编号
```
void insert(const string &s,int id)
{
int now=0,l=len[id];
for(int i=0;i<l;i++)
{
int c=idx(s[i]);
now=tr[now][c]?tr[now][c]:tr[now][c]=++cnt;
}
val[now]=id;
}
```
然后遍历$Trie$树,如果一个节点的子节点没有被标号,就把它并入当前节点的集合;否则把这个子节点作为当前节点所在集合的根的儿子(就是连一条边)
```
void make(int x)
{
for(int v,i=0;i<26;i++)
if(v=tr[x][i])
{
if(!val[v])
f[v]=find(x);
else
t[val[find(x)]].push_back(val[v]);
make(v);
}
}
```
### 代码:
```
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define int long long
using namespace std;
const int maxl=510010,maxn=1e5+10;
int n,tr[maxl][30],val[maxl],cnt,len[maxn],size[maxn],tot,f[maxl],id[maxn],ans;
vector<int>t[maxn];
string st[maxn];
inline int find(int x)
{
return x==f[x]?x:f[x]=find(f[x]);
}
inline int idx(char c)
{
return c-'a';
}
void insert(const string &s,int id)
{
int now=0,l=len[id];
for(int i=0;i<l;i++)
{
int c=idx(s[i]);
now=tr[now][c]?tr[now][c]:tr[now][c]=++cnt;
}
val[now]=id;
}
void make(int x)
{
for(int v,i=0;i<26;i++)
if(v=tr[x][i])
{
if(!val[v])
f[v]=find(x);
else
t[val[find(x)]].push_back(val[v]);
make(v);
}
}
bool cmp(const int &x,const int &y)
{
return size[x]<size[y];
}
void makes(int x)
{
size[x]=1;
for(int i=0;i<t[x].size();i++)
{
makes(t[x][i]);
size[x]+=size[t[x][i]];
}
sort(t[x].begin(),t[x].end(),cmp);
}
void dfs(int x)
{
id[x]=tot++;
for(int i=0;i<t[x].size();i++)
{
ans+=tot-id[x];
dfs(t[x][i]);
}
}
signed main()
{
scanf("%lld",&n);
for(int i=1;i<=n;i++)
{
cin>>st[i];
len[i]=st[i].length();
for(int j=0;j<len[i]/2;j++)
swap(st[i][j],st[i][len[i]-j-1]);
insert(st[i],i);
}
for(int i=1;i<=cnt;i++)
f[i]=i;
make(0),makes(0),dfs(0);
printf("%lld\n",ans);
return 0;
}
```