P3294 【[SCOI2016]背单词】

communist

2018-10-20 17:36:52

Solution

### 阅读理解题 题意见楼上题解 ------------ ### $Trie$ 后缀问题不好处理,我们把它转化为前缀问题,用字典树解决问题 ### 贪心 容易想到,一个串的后缀要先于它插入 对于一个串和其若干后缀串,容易想到,我们要先插入后缀串 然后递归进入$size$最小的子串 ``` bool cmp(const int &x,const int &y) { return size[x]<size[y]; } void makes(int x) { size[x]=1; for(int i=0;i<t[x].size();i++) { makes(t[x][i]); size[x]+=size[t[x][i]]; } sort(t[x].begin(),t[x].end(),cmp); } void dfs(int x) { id[x]=tot++; for(int i=0;i<t[x].size();i++) { ans+=tot-id[x]; dfs(t[x][i]); } } ``` ### 注意 求$size$要重构树,只保留关键点 ### 和楼上不一样的地方 #### 因为我太蒻了,并不会指针,所以提供一个并查集重构树的方法 在建$Trie$时给所有串的结尾和$Trie$树的根节点标号,表示新树中点的编号 ``` void insert(const string &s,int id) { int now=0,l=len[id]; for(int i=0;i<l;i++) { int c=idx(s[i]); now=tr[now][c]?tr[now][c]:tr[now][c]=++cnt; } val[now]=id; } ``` 然后遍历$Trie$树,如果一个节点的子节点没有被标号,就把它并入当前节点的集合;否则把这个子节点作为当前节点所在集合的根的儿子(就是连一条边) ``` void make(int x) { for(int v,i=0;i<26;i++) if(v=tr[x][i]) { if(!val[v]) f[v]=find(x); else t[val[find(x)]].push_back(val[v]); make(v); } } ``` ### 代码: ``` #include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<vector> #define int long long using namespace std; const int maxl=510010,maxn=1e5+10; int n,tr[maxl][30],val[maxl],cnt,len[maxn],size[maxn],tot,f[maxl],id[maxn],ans; vector<int>t[maxn]; string st[maxn]; inline int find(int x) { return x==f[x]?x:f[x]=find(f[x]); } inline int idx(char c) { return c-'a'; } void insert(const string &s,int id) { int now=0,l=len[id]; for(int i=0;i<l;i++) { int c=idx(s[i]); now=tr[now][c]?tr[now][c]:tr[now][c]=++cnt; } val[now]=id; } void make(int x) { for(int v,i=0;i<26;i++) if(v=tr[x][i]) { if(!val[v]) f[v]=find(x); else t[val[find(x)]].push_back(val[v]); make(v); } } bool cmp(const int &x,const int &y) { return size[x]<size[y]; } void makes(int x) { size[x]=1; for(int i=0;i<t[x].size();i++) { makes(t[x][i]); size[x]+=size[t[x][i]]; } sort(t[x].begin(),t[x].end(),cmp); } void dfs(int x) { id[x]=tot++; for(int i=0;i<t[x].size();i++) { ans+=tot-id[x]; dfs(t[x][i]); } } signed main() { scanf("%lld",&n); for(int i=1;i<=n;i++) { cin>>st[i]; len[i]=st[i].length(); for(int j=0;j<len[i]/2;j++) swap(st[i][j],st[i][len[i]-j-1]); insert(st[i],i); } for(int i=1;i<=cnt;i++) f[i]=i; make(0),makes(0),dfs(0); printf("%lld\n",ans); return 0; } ```