CF149D Coloring Brackets(区间dp,记忆化搜索)

hhz6830975

2018-02-06 14:26:23

Solution

刷区间dp的时候刷到的好题,安利一波这个系列:https://www.luogu.org/blog/hhz6830975/kuangbin-dai-ni-fei-zhuan-ti-er-shi-er-ou-jian-dp 设$dp[l][r][x][y]$为区间$[l,r]$两端颜色分别为$x$、$y$的方案数(0为不染色,1为红色,2为蓝色) 对于$dp[l][r][x][y]$,<font color=#A52A2A>若保证$[l,r]$为一个合法的序列</font>,则: 1.若$l+1=r$,显然序列为 $$()$$ 则有 ```cpp dp[l][r][0][1]=dp[l][r][0][2]=dp[l][r][1][0]=dp[l][r][2][0]=1; ``` 2.若括号$l$与$r$配对,则对应的序列为 $$(...)$$ 那么有 ```cpp for(int i=0;i<=2;i++) for(int j=0;j<=2;j++){ if(j!=1) dp[l][r][0][1]=(dp[l][r][0][1]+dp[l+1][r-1][i][j])%mod; if(j!=2) dp[l][r][0][2]=(dp[l][r][0][2]+dp[l+1][r-1][i][j])%mod; if(i!=1) dp[l][r][1][0]=(dp[l][r][1][0]+dp[l+1][r-1][i][j])%mod; if(i!=2) dp[l][r][2][0]=(dp[l][r][2][0]+dp[l+1][r-1][i][j])%mod; } ``` 3.若括号$l$与$r$不配对,则对应的序列为 $$(...)...(...)$$ 设与括号$l$配对的括号为$match[l]$,则有 ```cpp for(int i=0;i<=2;i++) for(int j=0;j<=2;j++) for(int p=0;p<=2;p++) for(int q=0;q<=2;q++){ if((j==1&&p==1)||(j==2&&p==2)) continue; dp[l][r][i][q]=(dp[l][r][i][q]+dp[l][match[l]][i][j]*dp[match[l]+1][r][p][q]%mod)%mod; } ``` 以上为状态转移方式 需要注意的是,本题不采用一般的dp顺序,而是使用记忆化搜索 原因在于本题要求配对的括号有且只有一个染色,但按一般的方式难以保证配对的括号在同一区间内,即不能保证转移中用到的序列都是合法的,处理比较麻烦 而采用记忆化搜索的方式,所有操作都是在原来的合法序列上增加/删除一对括号,或是将合法序列分割成两个合法序列/将两个合法序列合并,得到的仍然是合法序列,处理起来就方便得多;而且还不需要考虑不合法的状态,时间上也得到优化 代码: ```cpp #include <iostream> #include <cstdio> #include <cstring> #include <stack> using namespace std; const int maxn=710; const int mod=1e9+7; int n,match[maxn]; long long dp[maxn][maxn][3][3]; char s[maxn]; stack <int> St; void dfs(int l,int r){ if(l+1==r) dp[l][r][0][1]=dp[l][r][0][2]=dp[l][r][1][0]=dp[l][r][2][0]=1; else if(match[l]==r){ dfs(l+1,r-1); for(int i=0;i<=2;i++) for(int j=0;j<=2;j++){ if(j!=1) dp[l][r][0][1]=(dp[l][r][0][1]+dp[l+1][r-1][i][j])%mod; if(j!=2) dp[l][r][0][2]=(dp[l][r][0][2]+dp[l+1][r-1][i][j])%mod; if(i!=1) dp[l][r][1][0]=(dp[l][r][1][0]+dp[l+1][r-1][i][j])%mod; if(i!=2) dp[l][r][2][0]=(dp[l][r][2][0]+dp[l+1][r-1][i][j])%mod; } } else{ dfs(l,match[l]),dfs(match[l]+1,r); for(int i=0;i<=2;i++) for(int j=0;j<=2;j++) for(int p=0;p<=2;p++) for(int q=0;q<=2;q++){ if((j==1&&p==1)||(j==2&&p==2)) continue; dp[l][r][i][q]=(dp[l][r][i][q]+dp[l][match[l]][i][j]*dp[match[l]+1][r][p][q]%mod)%mod; } } } int main(){ scanf("%s",s); n=strlen(s); for(int i=0;i<n;i++){ if(s[i]=='(') St.push(i); else match[St.top()]=i,match[i]=St.top(),St.pop(); } dfs(0,n-1); long long ans=0; for(int i=0;i<=2;i++) for(int j=0;j<=2;j++) ans=(ans+dp[0][n-1][i][j])%mod; printf("%lld\n",ans); return 0; } ```