【PKUWC2018】随机游走

2018-12-25 20:45:55


题目描述

给定一棵 $n$ 个结点的树,你从点 $x$ 出发,每次等概率随机选择一条与所在点相邻的边走过去。

有 $Q$ 次询问,每次询问给定一个集合 $S$ ,求如果从 $x$ 出发一直随机游走,直到点集 $S$ 中所有点都至少经过一次的话,期望游走几步。

特别地,点 $x$ (即起点)视为一开始就被经过了一次。

答案对 $998244353 $ 取模。

$1\leq n\leq 18$

$1\leq Q\leq 5000$

$1\leq k\leq n$

题解

又是一个 min-max 容斥 的板子……

显然有:

$E(\max(S)) = \sum_{T \subseteq S} (-1)^{\mid T \mid + 1}E(\min(T))$

考虑 $E(\min(T))$ 是什么意思呢……就是从 $x$ 出发,经过 $T$ 中的点各至少一次的期望步数

于是可以枚举 $S$ ,设 $f_u$ 表示从 $u$ 出发,经过 $S$ 中的点各至少一次的期望步数

于是有:

$$f_u=1+\frac{1}{deg_u}\sum_{u \to v} f_v$$

然后就可以高斯消元了,由于每次的枚举集合中的有效点总共不多,卡卡常就可以过了

注意一下这道题的特殊性,是一棵 ,于是可以把 $f_u$ 表示成 $a_u f_{fa_u}+b_u$ 的形式

之后直接随便搞搞就搞出来了……

设 $g_{T}$ 表示钦定的集合为 $T$ 时,从 $x$ 出发,经过 $T$ 中的所有点至少一次的期望步数

对于查询的一个询问 $S$ 来说,答案就是:

$$\sum_{T \subseteq S} (-1)^{\mid T \mid + 1} g_{T}$$

于是可以先 $FMT$ 一下,处理出后面那个的子集和,查询就可以 $O(1)$ 了

总的时间复杂度为 $O(n 2^n \log P+q)$

带上一个 $\log P$ 是因为在预处理的时候要用到一个求逆元

代码

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N = 20, mod = 998244353;

vector<int> g[N];

int n, q, x; ll ans[1 << N], deginv[N];

ll pw(ll a, ll b) {
    ll r = 1;
    for( ; b ; b >>= 1, a = a * a % mod)
        if(b & 1)
            r = r * a % mod;
    return r;
}

struct T {
    ll a, b;
    // f[u] = a*f[fa]+b
    T(ll a = 0, ll b = 0): a(a), b(b) {}
    T operator + (T t) {
        return (T) { (a + t.a) % mod, (b + t.b) % mod };
    }
} f[1 << N][N];

void dfs(int u, int fa, T *f, int S) {
    if(S & (1 << (u - 1))) return ;
    T sum;
    for(int v: g[u]) {
        if(v == fa) continue;
        dfs(v, u, f, S);
        sum = sum + f[v];
    }
    ll tmp = pw(1 - deginv[u] * sum.a % mod, mod - 2);
    f[u].a = deginv[u] * tmp % mod;
    f[u].b = (1 + sum.b * deginv[u] % mod) * tmp % mod;
}

int cnt[1 << 20];

int main() {
    ios :: sync_with_stdio(0);
    cin >> n >> q >> x;
    for(int i = 1, u, v ; i < n ; ++ i)
        cin >> u >> v,
        g[u].push_back(v),
        g[v].push_back(u);
    for(int i = 1 ; i <= n ; ++ i)
        deginv[i] = pw(g[i].size(), mod - 2);
    for(int s = 0 ; s < (1 << n) ; ++ s)
        cnt[s] = cnt[s >> 1] + (s & 1);
    for(int s = 1 ; s < (1 << n) ; ++ s)
        dfs(x, 0, f[s], s),
        ans[s] = (cnt[s] & 1 ? 1 : -1) * f[s][x].b % mod;
    for(int i = 1 ; i <= n ; ++ i)
        for(int s = 0 ; s < (1 << n) ; ++ s)
            if(s & (1 << (i - 1)))
                (ans[s] += ans[s - (1 << (i - 1))]) %= mod;
    for(int i = 1 ; i <= q ; ++ i) {
        int k, x = 0, y; cin >> k;
        for(int j = 1 ; j <= k ; ++ j)
            cin >> y,
            x |= 1 << (y - 1);
        cout << (ans[x] % mod + mod) % mod << endl;
    }
}