题解 P3600 【随机数生成器】

xyz32768

2018-10-23 17:56:51

Solution

讲一种不用期望 DP 的做法。 可以考虑对于每个 $i$ ,求出最后的最大值为 $i$ 的概率 $p_i$ ,然后答案就是: $$\sum_{i=1}^x i\times p_i$$ 众所周知,「最大值恰好为 $i$ 」往往是不好处理的。 所以考虑差分:设 $P_i$ 表示最后的最大值 $\le i$ 的概率,这样答案就是: $$\sum_{i=1}^xi\times(P_i-P_{i-1})$$ 现在的问题就是如何求 $P_i$ 。 首先我们知道,如果一个区间包含了其他的区间,那么这个区间对最后的最大值是没有影响的。可以删掉。 这样如果把所有的区间按照左端点排序,那么右端点是严格单调递增的。 发现这个 $P_i$ 其实是一个计数问题(因为方案数除以 $x^n$ 就是概率),设 $h[i]$ 表示最后的最大值 $\le i$ 的**方案数**。 这等价于每个区间的最小值都 $\le i$ ,也就相当于每个区间内都存在一个 $\le i$ 的数。 我们不妨枚举 $j$ 表示有多少个数 $\le i$ ,那么 $h[i]$ 就等于: $$g[j]\times i^j\times (x-i)^{n-j}$$ 其中 $g[j]$ 表示在 $[1,n]$ 内选出 $j$ 个点,使得每个区间至少包含一个点的方案数。 现在可以开始 DP 辣! 定义状态: $f[i][j]$ 表示前 $i$ 个位置放了 $j$ 个点且第 $i$ 个位置必须放点,覆盖所有左端点 $\le i$ 的区间的方案数。 预处理 $fl[i]$ 和 $fr[i]$ 表示覆盖位置 $i$ 的最左区间编号和最右区间编号。特殊地,如果 $i$ 不被任何区间包含,则 $fr[i]$ 为右端点严格小于 $i$ 的最后一个区间, $fl[i]=fr[i]+1$ 。 转移: $$f[i][j]=\sum_{fr[k]+1\ge fl[i]}f[k][j-1]$$ 发现可以参与转移的 $k$ 是一个区间。前缀和优化, $O(n^2)$ 。 $g[j]$ 就很容易算出: $$g[j]=\sum_{fr[i]=q}f[i][j]$$ 复杂度 $O(n^2)$ 。 代码: ```cpp #include <cmath> #include <cstdio> #include <cstring> #include <iostream> #include <algorithm> #define For(i, a, b) for (i = a; i <= b; i++) #define Rof(i, a, b) for (i = a; i >= b; i--) using namespace std; inline int read() { int res = 0; bool bo = 0; char c; while (((c = getchar()) < '0' || c > '9') && c != '-'); if (c == '-') bo = 1; else res = c - 48; while ((c = getchar()) >= '0' && c <= '9') res = (res << 3) + (res << 1) + (c - 48); return bo ? ~res + 1 : res; } template <class T> T Min(T a, T b) {return a < b ? a : b;} template <class T> T Max(T a, T b) {return a > b ? a : b;} const int N = 2005, ZZQ = 666623333, INF = 0x3f3f3f3f; int n, m, X, q, fl[N], fr[N], f[N][N], sum[N][N], g[N], h[N], divis, ans; struct node { int l, r; } a[N], b[N]; bool comp(node a, node b) { return a.l < b.l || (a.l == b.l && a.r < b.r); } int qpow(int a, int b) { int res = 1; while (b) { if (b & 1) res = 1ll * res * a % ZZQ; a = 1ll * a * a % ZZQ; b >>= 1; } return res; } int main() { int i, j, top = 0; n = read(); X = read(); q = read(); For (i, 1, n) fl[i] = INF, fr[i] = -INF; For (i, 1, q) b[i].l = read(), b[i].r = read(); sort(b + 1, b + q + 1, comp); For (i, 1, q) { if (i > 1 && b[i].l == b[i - 1].l) continue; while (top && a[top].r >= b[i].r) top--; a[++top] = b[i]; } q = top; For (i, 1, q) For (j, a[i].l, a[i].r) fl[j] = Min(fl[j], i), fr[j] = Max(fr[j], i); top = 0; For (i, 1, n) if (fl[i] == INF) fl[i] = top + 1, fr[i] = top; else top = Max(top, fr[i]); f[0][0] = sum[0][0] = 1; top = 0; For (i, 1, n) { while (top < i - 1 && fr[top] + 1 < fl[i]) top++; For (j, 1, i) f[i][j] = (sum[i - 1][j - 1] - (top ? sum[top - 1][j - 1] : 0) + ZZQ) % ZZQ; For (j, 0, i) sum[i][j] = (sum[i - 1][j] + f[i][j]) % ZZQ; } For (i, 1, n) if (fr[i] == q) For (j, 1, n) g[j] = (g[j] + f[i][j]) % ZZQ; For (i, 1, X) For (j, 1, n) h[i] = (h[i] + 1ll * qpow(i, j) * qpow(X - i, n - j) % ZZQ * g[j] % ZZQ) % ZZQ; Rof (i, X, 1) h[i] = (h[i] - h[i - 1] + ZZQ) % ZZQ; divis = qpow(qpow(X, n), ZZQ - 2); For (i, 1, X) ans = (ans + 1ll * i * h[i] % ZZQ * divis % ZZQ) % ZZQ; cout << ans << endl; return 0; } ```