题解 P3311 【[SDOI2014]数数】AC自动机-dp

wzj423

2018-07-17 15:30:24

Solution

看到这道题,很容易想到以下AC自动机上的dp方程: ```cpp /** f[0][i][j]表示该位不受限制,AC自动机节点为i,匹配到了第j位的方案数 f[1][i][j]表示该位受限制,AC自动机节点为i,匹配到了第j位的方案数 f[2][i][j]表示该位以及之前都是前导零, 因而有: for(int nx=0;nx<10;++nx) f[0][i][j]->f[0][c[i][nx]][j+1] for(int nx=0;nx<N[j];++nx) f[1][i][j]->f[0][c[i][nx]][j+1] f[1][i][j]->f[1][c[i][nx nx==N[j] ][j+1] */ ``` 但这样只能得到80分,会Wa两个点 这是为什么?我们可以看一看数据: ```cpp 0821463255 0383320100380 001153611528 066512243919 0321614037728 0835268196243 (4.in) ``` 这时835268196243XXXXX在AC自动机上的路径为0-8-3-5... 这本来是一组可行解,但因为错误地考虑了前导零而排除了 因而我们需要重新设计方程: ```cpp /** f[0][i][j]表示该位不受限制,AC自动机节点为i,匹配到了第j位的方案数 f[1][i][j]表示该位受限制,AC自动机节点为i,匹配到了第j位的方案数 f[2][i][j]表示该位以及之前都是前导零, AC自动机节点为i,匹配到了第j位的方案数 则有转移: for(int nx=0;nx<10;++nx) f[0][i][j]->f[0][c[i][nx]][j+1] for(int nx=0;nx<N[j];++nx) f[1][i][j]->f[0][c[i][nx]][j+1] f[1][i][j]->f[1][c[i][nx nx==N[j] ][j+1] f[1][0][0]->f[2][0][1] nx==0 f[1][0][0]->f[0][c[0][nx]][1] nx!=0 f[2][0][j]->f[2][0][j+1] nx==0 f[2][0][j]->f[0][c[0][i][j+1] */ ``` 这样就能AC了 80pts Ver: ```cpp #include <bits/stdc++.h> /*80*/ using namespace std; //defs============================== const int mod=1e9+7,MAXP=1700,MAXN=1300; int M; char N[MAXN],p[MAXN];int Nlen; long long ans; bool tot0; //AC========================================== int c[MAXP][10],val[MAXP],fail[MAXP],cnt,_rt=0; bool is_tot0(char *s) { int len=strlen(s); for(int i=0;i<len;++i) if(s[i]!='0') return false; return true; } void ins(char *s) { int len=strlen(s),now=_rt; for(int i=0;i<len;++i) { int v=s[i]-'0'; if(!c[now][v]) c[now][v]=++cnt; now=c[now][v]; } ++val[now]; } void build_fail() { queue<int> q; for(int i=0;i<10;++i) if(c[0][i]) q.push(c[0][i]); while(!q.empty()) { int u=q.front();q.pop(); for(int i=0;i<10;++i) if(c[u][i]) fail[c[u][i]]=c[fail[u]][i],val[c[u][i]]+=val[fail[c[u][i]]],q.push(c[u][i]); else c[u][i]=c[fail[u]][i]; } } //dp====================== int f[2][MAXP][MAXN]; /* f[0][i][j]表示该位不受限制,AC自动机节点为i,匹配到了第j位的方案数 f[1][i][j]表示该位受限制,AC自动机节点为i,匹配到了第j位的方案数 f[2][i][j]表示该位以及之前都是前导零, for(int nx=0;nx<10;++nx) f[0][i][j]->f[0][c[i][nx]][j+1] for(int nx=0;nx<N[j];++nx) f[1][i][j]->f[0][c[i][nx]][j+1] f[1][i][j]->f[1][c[i][nx nx==N[j] ][j+1] f[1][0][0]->f[2][0][1] nx==0 f[1][0][0]->f[0][c[0][nx]][1] nx!=0 f[2][0][j]->f[2][0][j+1] nx==0 f[2][0][j]->f[0][c[0][i][j+1] */ int vis[2][MAXP][MAXN]; void pp(int i,int j,int k) { // printf("f[%d][%d][%d]=%d\n",i,j,k,f[i][j][k]); } //main================================= int main() { scanf("%s",N+1);Nlen=strlen(N+1); scanf("%d",&M); for(int i=1;i<=M;++i) scanf("%s",p),ins(p),tot0|=is_tot0(p); build_fail(); //for(int i=0;i<=cnt;++i) // printf("id=%d 0=%d 1=%d 2=%d 3=%d 4=%d val=%d fail=%d\n",i,c[i][0],c[i][1],c[i][2],c[i][3],c[i][4],val[i],fail[i]); vis[1][0][0]=f[1][0][0]=1; for(int j=0;j<Nlen;++j){ for(int i=0;i<=cnt;++i) { for(int limit=0;limit<=1;++limit) { if(!vis[limit][i][j]) continue; // printf("from (%d,%d,%d)\n",limit,i,j); if(limit==0) { //puts("limit=0"); for(int nx=0;nx<10;++nx) if(!val[c[i][nx]]) (f[0][c[i][nx]][j+1]+=f[0][i][j])%=mod,vis[0][c[i][nx]][j+1]=true,pp(0,c[i][nx],j+1); } else { //puts("limit=1"); for(int nx=0;nx<N[j+1]-'0';++nx) if( /*!(j==0&&nx==0) &&*/ !val[c[i][nx]]) (f[0][c[i][nx]][j+1]+=f[1][i][j])%=mod,vis[0][c[i][nx]][j+1]=true,pp(0,c[i][nx],j+1); if(!val[c[i][ N[j+1]-'0' ]]) { (f[1][c[i][ N[j+1]-'0' ]][j+1]+=f[1][i][j])%=mod; vis[1][ c[i][ N[j+1]-'0' ] ][j+1]=true; //puts("special"); pp(1,c[i][N[j+1]-'0'],j+1); } } } } } for(int i=0;i<=cnt;++i) { (ans+=f[0][i][Nlen])%=mod; (ans+=f[1][i][Nlen])%=mod; } if(tot0) printf("%lld\n",ans); else printf("%lld\n",ans-1); return 0; } ``` 100pts Ver ```cpp #include <bits/stdc++.h> using namespace std; //defs============================== const int mod=1e9+7,MAXP=1700,MAXN=1300; int M; char N[MAXN],p[MAXN];int Nlen; long long ans; bool tot0; //AC========================================== int c[MAXP][10],val[MAXP],fail[MAXP],cnt,_rt=0; bool is_tot0(char *s) { int len=strlen(s); for(int i=0;i<len;++i) if(s[i]!='0') return false; return true; } void ins(char *s) { int len=strlen(s),now=_rt; for(int i=0;i<len;++i) { int v=s[i]-'0'; if(!c[now][v]) c[now][v]=++cnt; now=c[now][v]; } ++val[now]; } void build_fail() { queue<int> q; for(int i=0;i<10;++i) if(c[0][i]) q.push(c[0][i]); while(!q.empty()) { int u=q.front();q.pop(); for(int i=0;i<10;++i) if(c[u][i]) fail[c[u][i]]=c[fail[u]][i],val[c[u][i]]+=val[fail[c[u][i]]],q.push(c[u][i]); else c[u][i]=c[fail[u]][i]; } } //dp====================== int f[3][MAXP][MAXN]; /* f[0][i][j]表示该位不受限制,AC自动机节点为i,匹配到了第j位的方案数 f[1][i][j]表示该位受限制,AC自动机节点为i,匹配到了第j位的方案数 f[2][i][j]表示该位以及之前都是前导零, for(int nx=0;nx<10;++nx) f[0][i][j]->f[0][c[i][nx]][j+1] for(int nx=0;nx<N[j];++nx) f[1][i][j]->f[0][c[i][nx]][j+1] f[1][i][j]->f[1][c[i][nx nx==N[j] ][j+1] f[1][0][0]->f[2][0][1] nx==0 f[1][0][0]->f[0][c[0][nx]][1] nx!=0 f[2][0][j]->f[2][0][j+1] nx==0 f[2][0][j]->f[0][c[0][i]][j+1] nx!=0 */ int vis[3][MAXP][MAXN]; void pp(int i,int j,int k) { //printf("f[%d][%d][%d]=%d\n",i,j,k,f[i][j][k]); } inline void add(int &a,int b) { (a+=b)%=mod; } //main================================= int main() { scanf("%s",N+1);Nlen=strlen(N+1); scanf("%d",&M); for(int i=1;i<=M;++i) scanf("%s",p),ins(p),tot0|=is_tot0(p); build_fail(); //for(int i=0;i<=cnt;++i) // printf("id=%d 0=%d 1=%d 2=%d 3=%d 4=%d val=%d fail=%d\n",i,c[i][0],c[i][1],c[i][2],c[i][3],c[i][4],val[i],fail[i]); vis[1][0][0]=f[1][0][0]=1; for(int j=0;j<Nlen;++j){ for(int i=0;i<=cnt;++i) { for(int limit=0;limit<=2;++limit) { if(!vis[limit][i][j]) continue; if(limit==0) { for(int nx=0;nx<10;++nx) if(!val[c[i][nx]]) add(f[0][c[i][nx]][j+1],f[0][i][j]),vis[0][c[i][nx]][j+1]=true,pp(0,c[i][nx],j+1); } else if(limit==1){ for(int nx=(j==0);nx<N[j+1]-'0';++nx) if(!val[c[i][nx]]) { add(f[0][c[i][nx]][j+1],f[1][i][j]); vis[0][c[i][nx]][j+1]=true; pp(0,c[i][nx],j+1); } if(j==0){ add(f[2][0][1],f[1][0][0]); vis[2][0][1]=true; pp(2,0,1); } if(!val[c[i][ N[j+1]-'0' ]]) { add(f[1][ c[i][ N[j+1]-'0' ]][j+1],f[1][i][j]); vis[1][ c[i][ N[j+1]-'0' ] ][j+1]=true; pp(1,c[i][N[j+1]-'0'],j+1); } } else if(limit==2) { add(f[2][0][j+1],f[2][0][j]); vis[2][0][j+1]=1; pp(2,0,j+1); for(int nx=1;nx<10;++nx) if(!val[c[i][nx]]) { add(f[0][c[i][nx]][j+1],f[2][0][j]); vis[0][c[i][nx]][j+1]=true; pp(0,c[i][nx],j+1); } } } } } for(int i=0;i<=cnt;++i) { (ans+=f[0][i][Nlen])%=mod; (ans+=f[1][i][Nlen])%=mod; } printf("%lld\n",ans); return 0; } ```