朝田诗乃 的博客

朝田诗乃 的博客

题解 P4229 【某位歌姬的故事】

posted on 2019-01-31 11:59:38 | under 题解 |

更好的阅读体验点这里QAQ

看到 $n$ 的范围首先想到离散化。按照套路将数据转化成几个限制区间的形式,对于每个限制区间维护区间编号 $id$ 、区间长度 $len$ 以及区间中可以达到的最大音高 $lim$ 。

然后把这一堆东西拿去 $dp$ 。这里dp方法有很多,我的是这样子的:

设 $f[i][j]$ 表示表示满足了前 $i$ 个区间,处理到了位置 $j$ ,且 $j$ 上放的是最大值,并且 $j$ 后不能再出现最大值的方案数。

列出方程后可以发现转移是 $O(n)$ 的,总复杂度就是 $O(n^2logn)$ [有一个快速幂]原地起飞。

使用前缀和优化将转移变成 $O(1)$ ,总复杂度变为 $O(nlogn)$ 可过。

#include <algorithm>
#include <iostream>
#include <cstdio>
#include <cstring>
#include <string>
#include <cmath>
#include <cstdlib>
using namespace std;
const int MAXM = 550, MAXN = 2050, M = 998244353;
typedef long long lint;
int read(int &x) {
    char ch; int f = 1;
    while(ch = getchar(), ch < '!' && ch != '-') if(ch == '-') f = -1; x = ch - 48;
    while(ch = getchar(), ch > '!') x = (x << 3) + (x << 1) + ch - 48; x *= f; return x;
}
struct Que {int l, r, h;} q[MAXM], tmp[MAXM];
int cmp(Que a, Que b) {return a.h == b.h ? a.r < b.r : a.h < b.h;}
int len[MAXN], pos[MAXN], dd[MAXN], val[MAXN], cntdd, id[MAXN], cntseg, cnttl, cntt;
int f[2][MAXN], sf[2][MAXN], inv[MAXN], slen[MAXN], lim[MAXN], n, T, m, Q, cntval, tmplen[MAXN];
int power(int a, int b) {
    int res = 1;
    for(; b; b >>= 1, a = (1ll * a * a) % M)
        if(b & 1) res = (1ll * res * a) % M;
    return res;
}
int Inv(int x) {return power(x, M - 2);}
void trans(int &mul, int h) {
    if(h == 1) return;
    memset(f, 0, sizeof f); memset(sf, 0, sizeof sf);
    int cur = 0;
    f[0][0] = sf[0][0] = inv[0] = 1;
    for(int i = 1; i <= cnttl; ++i)
        sf[cur][i] = sf[cur][i-1], slen[i] = slen[i-1] + tmplen[i], inv[i] = Inv(power(h-1, slen[i]));
    tmp[0].r = 0;
    for(int i = 1; i <= cntt; ++i) {
        cur ^= 1;
        for(int j = 0; j <= tmp[i].l - 1; ++j) f[cur][j] = sf[cur][j] = 0;
        for(int j = tmp[i].l; j <= tmp[i].r; ++j) {
            f[cur][j] = f[cur ^ 1][j];
            if(j > tmp[i-1].r) {
                f[cur][j] += 1ll * power(h-1, slen[tmp[i-1].r]) * power(h, slen[j-1] - slen[tmp[i-1].r]) % M * sf[cur^1][tmp[i-1].r] % M * (power(h, tmplen[j]) - power(h-1, tmplen[j]) + M) % M;
                f[cur][j] %= M;         
            }
            sf[cur][j] = (sf[cur][j-1] + 1ll * f[cur][j] * inv[j] % M) % M;
        }
        for(int j = tmp[i].r + 1; j <= cnttl; ++j) sf[cur][j] = sf[cur][j-1], f[cur][j] = 0;
    }
    mul = 1ll * mul * sf[cur][cnttl] % M * power(h-1, slen[cnttl]) % M;
}
int main() {
    read(T);
    while(T--) {
        read(n); read(Q); read(m); cntdd = cntval = 0;
        memset(lim, 0, sizeof lim);
        for(int l, r, h, i = 1; i <= Q; ++i) {
            q[i].l = read(l); q[i].r = read(r); q[i].h = read(h);
            dd[++cntdd] = l; dd[++cntdd] = r; val[i] = h;
        }
        sort(q+1, q+Q+1, cmp);
        dd[++cntdd] = 1; dd[++cntdd] = n; cntseg = 0;
        sort(dd+1, dd+cntdd+1); cntdd = unique(dd+1, dd+cntdd+1) - dd - 1;
        for(int i = 1; i <= cntdd; ++i) {
            id[++cntseg] = dd[i]; len[cntseg] = 1;
            if(i != cntdd && dd[i] + 1 < dd[i+1]) id[++cntseg] = dd[i]+1, len[cntseg] = dd[i+1] - dd[i] - 1;
        }
        bool flag = 0;
        for(int l, r, i = 1; i <= Q; ++i) {
            l = (q[i].l = lower_bound(id+1, id+cntseg+1, q[i].l) - id); 
            r = (q[i].r = lower_bound(id+1, id+cntseg+1, q[i].r) - id);
            int maxx = 0; bool has = 0;
            for(int j = l; j <= r; ++j)
                if(!lim[j]) lim[j] = q[i].h, has = 1;
                else maxx = max(maxx, lim[j]);
            if(!has && maxx < q[i].h) flag = 1;
        }
        if(flag) {puts("0"); continue;}
        sort(val+1, val+Q+1); cntval = unique(val+1, val+Q+1) - val - 1;
        //离散化完成
        int ans = 1, i = 1;
        for(int j = 1; j <= cntval; ++j) {
            cnttl = 0, cntt = 0;
            while(i <= Q && q[i].h == val[j]) tmp[++cntt] = q[i++];
            for(int k = 1; k <= cntseg; ++k) if(lim[k] == val[j]) tmplen[pos[k] = ++cnttl] = len[k];
            for(int k = 1; k <= cntt; ++k) {
                for(; lim[tmp[k].l] != val[j]; ++tmp[k].l);
                for(; lim[tmp[k].r] != val[j]; --tmp[k].r);
                tmp[k].l = pos[tmp[k].l]; tmp[k].r = pos[tmp[k].r];
            } trans(ans, val[j]);
        }
        for(int i = 1; i <= cntseg; ++i) if(!lim[i]) ans = 1ll * ans * power(m, len[i]) % M;
        printf("%d\n", ans); 
    }
}