题解 CF1158F 【Density of subarrays】
ywy_c_asm
2019-06-09 19:27:42
纪念一下发的第100篇题解……
首先我们应该知道,有值的$p$至多是$\lfloor\frac n c\rfloor$,因为一个密度为$p$的序列最短的长度就是$pc$,此时它可以划分为$p$段,每段都是$1\cdots c$的一个排列。
然后就有这么一个$O(\frac{n^3}c)$的暴力,我们发现所有$c^p$个序列中,分别以$1\cdots c$为开头的各有$c^{p-1}$个,那么也就是说我们可以从子问题划分的角度去考虑,首先这个序列里肯定得有$1\cdots c$所有的,我们考虑每种数的第一个出现的位置,显然在它之后得有个包含了所有$c^{p-1}$的子序列(即密度至少是$p-1$),其实不难发现我们只需要让最晚的第一个出现的位置之后有一个密度至少是$p-1$的子序列就行了。
那么这样的话我们可以想到一个dp,$dp[i][j]$表示**以$i$为开头**(这一点在子序列的dp中要尤其注意)的密度**至少**为j的子序列个数,我们最后求出来密度至少为$p$的子序列有多少个,做个差分就行了。在这种子序列dp中我们必须要慎重,稍不留神就会算重,那么我们考虑对于一个确定的子序列如何使用dp的方式去**确定**它,刚才都说过了,我们可以考虑最晚的一个第一次出现的位置,也就是说,在这个位置之前的这个子序列的前缀第一次出现了全部的$1\cdots c$,我们考虑去枚举它。我们可以预处理一个$f[l,r]$表示开头为$l$,结尾为$r$,有多少子序列满足$1\cdots c$全部出现并且$a_r$仅在$r$位置出现了一次。这个容易,我们维护一个cnt数组,那么就是$[a_l\not=a_r]2^{cnt_{a_l}-1}\prod_{i\not=a_l\&\&i\not=a_r}(2^{cnt_{i}}-1)$,这个可以用逆元什么的$O(n^2)$算。于是我们就可以转移了,我们做个前缀和$sum[i][j]=\sum_{k>=i}dp[k][j]$,那么$dp[i][j]=\sum_{k=i+1}^nf[i,k]sums[k+1][j-1]$。
然而当c特别小的时候这个暴力还是过不去,我们考虑另一个c越小越好的暴力(显然咱们要对数据分治……你发现上面那个暴力的复杂度的c放到了分母上那肯定就要这样啊……),我们发现上面那个dp实际上是在干这个事,就是**贪心的把一个子序列划分为若干段,如果当前段一旦开始出现了全部的c那么就赶紧分段**,那么我们可以在c特别小的时候把序列中出现的数状压一下,即$dp[i][j][S]$表示前i个(不一定以i结尾),现在分了j次,最后一段里面出现的数字集合为$S$,这个dp就非常显然了,我们只要在S满了的时候强行转移到$dp[i+1][j+1][0]$即可。注意这个的$j$并不是至少而是恰好,就不用做差分了。
然后我们在$c<=\log n$的时候用暴力2,否则用暴力1,考虑这样的复杂度,暴力1是$O(\frac{n^3}{\log n})$的,暴力2是$O(n2^n\frac{n}{c})$的,而$\frac{2^c}c$由于随着c的增大而增大,所以它$<=\frac n{\log n}$,所以也是$O(\frac{n^3}{\log n})$的。
然后$c=1$就用组合数特判一下就行了。
上代码~
```cpp
#include <iostream>
#include <cstdio>
#include <cstring>
#define ll unsigned long long
#define int unsigned int
#define p 998244353
#define md(_o) ((_o >= p) ? (_o)-p : (_o))
using namespace std;
namespace ywy {
inline ll mi(int a, int b) {
ll ans = 1, tmp = a;
while (b) {
if (b & 1)
ans = (ans * tmp) % p;
tmp = (tmp * tmp) % p;
b >>= 1;
}
return (ans);
}
ll dk[3333], dk1[3333], invdk[3333], invdk1[3333];
int ints[3333], cnt[3333];
ll sums[3333][3333], g[3333];
int f[3333][3333], dp[3333][3333];
namespace csmall {
int dp[2][3333][1024];
void solve(int n, int c) {
dp[0][0][0] = 1;
for (register int i = 1; i <= n; i++) {
for (register int j = 0; j <= i && j <= n / c; j++) {
for (register int k = 0; k + 1 < (1 << c); k++) dp[i & 1][j][k] = 0;
}
for (register int j = 0; j <= i && j <= n / c; j++) {
for (register int k = 0; k + 1 < (1 << c); k++) {
if (!dp[(i & 1) ^ 1][j][k])
continue;
dp[i & 1][j][k] = md(dp[i & 1][j][k] + dp[(i & 1) ^ 1][j][k]);
if ((k | (1 << (ints[i] - 1))) == (1 << c) - 1) {
dp[i & 1][j + 1][0] = md(dp[i & 1][j + 1][0] + dp[(i & 1) ^ 1][j][k]);
} else {
dp[i & 1][j][k | (1 << (ints[i] - 1))] =
md(dp[i & 1][j][k | (1 << (ints[i] - 1))] + dp[(i & 1) ^ 1][j][k]);
}
}
}
}
g[0] = p - 1;
for (register int i = 0; i <= n; i++) {
for (register int j = 0; j < (1 << c); j++) {
g[i] = (g[i] + dp[n & 1][i][j]) % p;
}
printf("%lld ", g[i]);
}
}
}
void ywymain() {
int n, c;
cin >> n >> c;
for (register int i = 1; i <= n; i++) cin >> ints[i];
if (c == 1) {
f[0][0] = 1;
for (register int i = 1; i <= n; i++) {
f[i][0] = 1;
for (register int j = 1; j <= i; j++) f[i][j] = (f[i - 1][j] + f[i - 1][j - 1]) % p;
}
printf("0 ");
for (register int i = 1; i <= n; i++) printf("%lld ", f[n][i]);
return;
}
dp[n + 1][0] = 1;
dk[0] = 1;
invdk[0] = 1;
invdk1[0] = 1;
for (register int i = 1; i <= n; i++) {
dk[i] = (dk[i - 1] << 1) % p;
dk1[i] = dk[i] - 1;
invdk[i] = mi(dk[i], p - 2);
invdk1[i] = mi(dk1[i], p - 2);
}
if (c <= 10) {
csmall::solve(n, c);
return;
}
for (register int l = 1; l <= n; l++) {
memset(cnt, 0, sizeof(cnt));
int tot = 1;
ll ji = 1;
cnt[ints[l]] = 1;
for (register int r = l + 1; r <= n; r++) {
if (ints[r] == ints[l]) {
ji *= 2;
ji %= p;
} else {
ji *= invdk1[cnt[ints[r]]];
ji %= p;
cnt[ints[r]]++;
ji *= dk1[cnt[ints[r]]];
ji %= p;
if (cnt[ints[r]] == 1)
tot++;
}
if (tot == c && ints[r] != ints[l])
f[l][r] = (ji * invdk1[cnt[ints[r]]]) % p;
}
}
sums[n + 1][0] = 1;
for (register int i = n; i >= 1; i--) {
for (register int j = 1; j <= n - i + 1; j++) {
ll tot = 0;
for (register int k = i + c - 1; k <= n && sums[k + 1][j - 1]; k++)
tot = (tot + (ll)f[i][k] * sums[k + 1][j - 1]) % p;
if (!tot)
break;
dp[i][j] = tot;
}
dp[i][0] = dk[n - i];
for (register int j = 0; j <= n - i + 1; j++) sums[i][j] = (sums[i + 1][j] + dp[i][j]) % p;
}
sums[1][0]--;
for (register int i = 0; i <= n; i++) printf("%lld ", (sums[1][i] + p - sums[1][i + 1]) % p);
}
}
signed main() {
ywy::ywymain();
return (0);
}
```