P14808 [CCPC 2024 哈尔滨站] 子序列计数
EuphoricStar · · 题解
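令 $k \leftarrow k^{-1} \bmod L$(对应代码中的 `K = inv(K, L)`;下标从 $0$ 开始)。这样要统计的序列就是 $s_0, s_k, s_{2k \bmod L}, \dots, s_{(L-1)k \bmod L}$,即以步长 $k$ 在模 $L$ 意义下依次访问原串的各个位置。
我们发现子序列的 DP 可以用矩阵描述。与其去算子序列,我们不妨去算一个广义的信息:把每个位置 $p$ 对应到一个 $(m+1)\times(m+1)$ 的转移矩阵 $M_p$(单位矩阵,并且对每个满足 $s_p = t_j$ 的 $j$ 令 $M_p[j-1][j] = 1$),那么答案就是 $\left(\prod_{j=0}^{L-1} M_{jk \bmod L}\right)[0][m]$。一般地,对长度为 $l$、由若干连续段(同一段内矩阵相同)给出的矩阵序列 $M_0, \dots, M_{l-1}$,以及满足 $\gcd(k, l) = 1$ 的步长 $k$,我们求 $\prod_{j=0}^{l-1} M_{jk \bmod l}$。
为了直观理解这个问题,我们可以把所有位置排成一个 $k$ 列的网格:位置 $p$ 放在第 $\lfloor p/k \rfloor$ 行、第 $p \bmod k$ 列(共 $\lceil l/k \rceil$ 行,最后一行可能不满)。
发现我们走的顺序是先从第 $0$ 列自上而下走完(位置 $0, k, 2k, \dots$),再换到另一列继续自上而下走。设第 $c$ 列的最后一个位置是 $c + (r-1)k$,则下一个位置是 $c + rk - l \in [0, k)$,它位于第 $(c - l) \bmod k$ 列。因此各列按步长 $k - (l \bmod k)$(模 $k$)依次被访问。
那么我们把每一列自上而下的矩阵乘积算出来,就可以递归到 $l' = k$、$k' = k - (l \bmod k)$ 的子问题:新问题中位置 $c$ 的矩阵就是原网格第 $c$ 列的乘积,并且 $\gcd(k', l') = \gcd(l \bmod k, k) = 1$ 仍然成立。
同时若初始的序列由 $n$ 段组成,那么把每段的起点以及 $l$ 都对 $k$ 取模,得到的至多 $n+1$ 个断点会把 $[0, k)$ 分成至多 $n+1$ 个区间。同一区间内各列的乘积完全相同,所以新问题仍可以用至多 $n+1$ 段描述。
要求出对应的段以及矩阵乘积是简单的,我们开一棵线段树维护新的每一段的矩阵乘积。那么原本的一段 $[x, y]$ 有两种情况。要么它落在网格的同一行内,此时对相应的列区间右乘该段的矩阵。要么它由第一行的一个后缀、中间 $\lfloor (y+1)/k \rfloor - \lfloor x/k \rfloor - 1$ 个完整行和最后一行的一个前缀组成,此时对后缀、前缀各做一次区间右乘,再对所有列右乘该矩阵的相应次幂。区间右乘用懒标记实现即可;注意矩阵乘法不交换,标记要按时间顺序右乘。
有个小问题,若 $k$ 接近 $l$(例如 $k = l - 1$),那么 $l' = k$ 只比 $l$ 小一点,递归层数就无法保证是 $O(\log L)$。
我们做一个修正:若 $2k > l$,就保持位置 $0$ 不动、把位置 $1, \dots, l-1$ 翻转,即令 $M'_p = M_{(-p) \bmod l}$。由于 $jk \equiv -j(l-k) \pmod l$,原乘积等于在 $M'$ 上以 $l-k$ 为步长的乘积,而 $l - k < l/2$。于是至多经过两次递归 $l$ 就至少减半,层数为 $O(\log L)$,且每层段数至多增加常数。
总时间复杂度为 $O(n m^3 \log n \log L)$。共有 $O(\log L)$ 层,每层对每段做 $O(1)$ 次线段树区间右乘,每次需要 $O(\log n)$ 次 $O(m^3)$ 的矩阵乘法。快速幂的指数不超过 $l/k$,各层的 $\log(l/k)$ 相加是 $O(\log L)$,不影响总复杂度。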
:::info[代码]
#include <bits/stdc++.h>
// Short aliases used by the rest of the file: pb/fst/scd/mkp appear in
// work(), SGT::dfs() and solve().
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
// mems(a, x): memset the whole object `a` (sizeof(a) bytes) to byte value x.
// Only correct when `a` names an array/object, never a pointer.
#define mems(a, x) memset((a), (x), sizeof(a))
// NOTE(review): file-scope `using namespace std` is kept because every other
// definition in this file relies on unqualified std names.
using namespace std;
// Numeric and pair aliases; several of them are unused in the visible code.
using ll = long long;
using ull = unsigned long long;
using db = double;
using ldb = long double;
using pii = pair<int, int>;
using pll = pair<ll, ll>;
// Sizes the breakpoint scratch array b (2 * maxn) and the segment-tree
// storage (4 * maxn); presumably chosen from the bound on the number of runs.
constexpr int maxn = 2020;
// Prime modulus for all counting arithmetic.
constexpr int mod = 998244353;

// Adds mod to x when x is negative. Callers use it to bring a value from
// [-mod, mod) back into [0, mod).
inline void fix(int &x) {
    if (x < 0) {
        x += mod;
    }
}

// n: number of input runs; m: pattern length; K: step, replaced by its
// inverse modulo L in solve(); L: modulus / total length passed to work().
int n, m, K, L;
// a[1..m]: pattern values. Matrices use indices 0..m of a 12x12 array.
int a[12];
// Breakpoint scratch array shared by work() and SGT::dfs().
int b[maxn << 1];
// Extended Euclid: sets x, y such that a * x + b * y == gcd(a, b).
// Assumes a, b >= 0 (TODO confirm no caller passes negatives).
void exgcd(int a, int b, int &x, int &y) {
    if (!b) {
        x = 1;
        y = 0;
        return;
    }
    // Solve for (b, a mod b) with x and y swapped, then correct y.
    exgcd(b, a % b, y, x);
    y -= a / b * x;
}

// Modular inverse of a modulo p, returned in [0, p).
// Requires gcd(a, p) == 1; otherwise the result is not an inverse and no
// error is reported.
inline int inv(int a, int p) {
    int x, y;
    exgcd(a, p, x, y);
    // Normalize with a branch instead of (x % p + p) % p: that sum overflows
    // int (undefined behavior) when x % p > 0 and p > INT_MAX / 2.
    const int r = x % p;
    return r < 0 ? r + p : r;
}
// Transfer-matrix storage: a fixed 12x12 int array, of which indices 0..m are
// used by the arithmetic in this file. A default-constructed mat is all zeros.
struct mat {
    int a[12][12];
    mat() : a{} {}
} I;  // identity matrix; its diagonal is filled in by solve() once m is read
// (m+1)x(m+1) matrix product modulo mod. Zero entries of either factor are
// skipped, which helps because the transfer matrices are mostly zero.
inline mat operator * (const mat &a, const mat &b) {
    mat prod;
    for (int row = 0; row <= m; ++row) {
        const int *lhsRow = a.a[row];
        int *outRow = prod.a[row];
        for (int mid = 0; mid <= m; ++mid) {
            const int lhs = lhsRow[mid];
            if (lhs == 0) {
                continue;
            }
            const int *rhsRow = b.a[mid];
            for (int col = 0; col <= m; ++col) {
                const int rhs = rhsRow[col];
                if (rhs == 0) {
                    continue;
                }
                // Add (lhs * rhs mod p) - p, then let fix() restore [0, mod).
                fix(outRow[col] += 1ULL * lhs * rhs % mod - mod);
            }
        }
    }
    return prod;
}
// Binary exponentiation: returns a^p, with a^0 being the identity I.
inline mat qpow(mat a, int p) {
    mat acc = I;
    for (; p != 0; p >>= 1) {
        if (p & 1) {
            acc = acc * a;
        }
        a = a * a;
    }
    return acc;
}
// Segment tree over leaves 1..n (heap layout: children of rt are 2rt, 2rt+1).
// It supports "right-multiply every leaf matrix in [ql, qr] by x" with lazy
// tags, and dfs() reads all leaves back in left-to-right order.
namespace SGT {
// For a leaf: the product of every matrix applied to it, in application order.
// For an internal node: a pending matrix still to be right-multiplied onto
// both children (meaningful only while vis[] is set for that node).
mat a[maxn << 2];
// Pending-tag flag; only consulted for internal nodes (in pushdown).
bool vis[maxn << 2];
// Resets every node of the subtree covering [l, r] to the identity with no
// pending tag.
void build(int rt, int l, int r) {
a[rt] = I;
vis[rt] = 0;
if (l == r) {
return;
}
int mid = (l + r) >> 1;
build(rt << 1, l, mid);
build(rt << 1 | 1, mid + 1, r);
}
// Right-multiplies node x by y. The newest factor goes on the right, because
// matrix multiplication is not commutative.
inline void pushtag(int x, const mat &y) {
a[x] = a[x] * y;
vis[x] = 1;
}
// Forwards internal node x's pending tag to both children. The children's
// older factors stay on the left. Then clears x's tag back to the identity.
inline void pushdown(int x) {
if (!vis[x]) {
return;
}
pushtag(x << 1, a[x]);
pushtag(x << 1 | 1, a[x]);
a[x] = I;
vis[x] = 0;
}
// Right-multiplies every leaf in [ql, qr] by x; an empty range (ql > qr) is
// a no-op. Pending tags are pushed down before descending, so order is kept.
void update(int rt, int l, int r, int ql, int qr, const mat &x) {
if (ql > qr) {
return;
}
if (ql <= l && r <= qr) {
pushtag(rt, x);
return;
}
pushdown(rt);
int mid = (l + r) >> 1;
if (ql <= mid) {
update(rt << 1, l, mid, ql, qr, x);
}
if (qr > mid) {
update(rt << 1 | 1, mid + 1, r, ql, qr, x);
}
}
// Appends (b[l + 1] - b[l], leaf matrix) for each leaf, left to right, pushing
// pending tags down on the way. In work() that pair is the width of the leaf's
// residue interval [b[l], b[l + 1]) and the product shared by those columns.
void dfs(int rt, int l, int r, vector<pair<int, mat>> &vc) {
if (l == r) {
vc.pb(b[l + 1] - b[l], a[rt]);
return;
}
pushdown(rt);
int mid = (l + r) >> 1;
dfs(rt << 1, l, mid, vc);
dfs(rt << 1 | 1, mid + 1, r, vc);
}
}
// Returns M[0] * M[k] * M[2k mod l] * ... * M[(l-1)k mod l]. The position
// matrices M[0..l-1] are given in `vc` as consecutive (run length, matrix)
// pairs in position order.
// Assumes: run lengths are >= 1 and sum to l; gcd(k, l) == 1 and 1 <= k < l
// whenever l > 1 (solve() passes k = K^{-1} mod L).
// Uses the global scratch array b and the segment tree SGT.
mat work(int l, int k, vector<pair<int, mat>> vc) {
    if (l == 1) {
        return vc[0].scd;
    }
    if (k > l - k) {  // same test as 2k > l, but cannot overflow int
        // Keep position 0 and reverse positions 1..l-1: M'[p] = M[(-p) mod l].
        // Since jk == -j(l - k) (mod l), stepping through M' by l - k gives the
        // same product, and l - k < l / 2 keeps the recursion depth logarithmic.
        mat A = vc[0].scd;
        reverse(vc.begin(), vc.end());
        // After reversing, the run that held position 0 is last; drop that cell.
        if (vc.back().fst == 1) {
            vc.pop_back();
        } else {
            --vc.back().fst;
        }
        vc.insert(vc.begin(), mkp(1, A));
        // vc is dead after this call: move it instead of copying every matrix.
        return work(l, l - k, std::move(vc));
    }
    // Lay positions out in a grid with k columns (p -> row p / k, column p % k).
    // The walk goes down whole columns, and after column c it continues at
    // column (c - l) mod k. So the column products form a sub-problem of
    // length k with step k - l % k.
    // Relevant column boundaries: every run start mod k, plus l mod k.
    int cnt = (int)vc.size();
    b[1] = 0;
    for (int i = 0; i < cnt; ++i) {
        // 64-bit sum: b[i + 1] + run length may exceed INT_MAX when l is huge.
        b[i + 2] = (int)((b[i + 1] + (long long)vc[i].fst) % k);
    }
    sort(b + 1, b + cnt + 2);
    cnt = (int)(unique(b + 1, b + cnt + 2) - b - 1);
    b[cnt + 1] = k;  // sentinel: leaf i covers residues [b[i], b[i + 1])
    SGT::build(1, 1, cnt);
    int s = 0;  // start position of the current run
    for (const auto &p : vc) {
        const int lo = s, hi = s + p.fst - 1;
        s += p.fst;
        const int x = lower_bound(b + 1, b + cnt + 1, lo % k) - b;
        const int y = lower_bound(b + 1, b + cnt + 1, (hi + 1) % k) - b - 1;
        if (lo / k == (hi + 1) / k) {
            // The whole run lies inside one grid row: columns x..y.
            SGT::update(1, 1, cnt, x, y, p.scd);
        } else {
            // Tail of the first row, then full rows, then head of the last row.
            SGT::update(1, 1, cnt, x, cnt, p.scd);
            const int full_rows = (hi + 1) / k - lo / k - 1;
            if (full_rows > 0) {  // a zero power is the identity: nothing to apply
                SGT::update(1, 1, cnt, 1, cnt, qpow(p.scd, full_rows));
            }
            SGT::update(1, 1, cnt, 1, y, p.scd);
        }
    }
    vector<pair<int, mat>> nv;
    SGT::dfs(1, 1, cnt, nv);
    return work(k, k - l % k, std::move(nv));
}
void solve() {
scanf("%d%d%d%d", &n, &m, &K, &L);
for (int i = 1; i <= m; ++i) {
scanf("%d", &a[i]);
}
K = inv(K, L);
for (int i = 0; i <= m; ++i) {
I.a[i][i] = 1;
}
vector<pair<int, mat>> vc;
for (int i = 1, x, y; i <= n; ++i) {
scanf("%d%d", &x, &y);
mat A = I;
for (int j = 1; j <= m; ++j) {
if (y == a[j]) {
A.a[j - 1][j] = 1;
}
}
vc.pb(x, A);
}
printf("%d\n", work(L, K, vc).a[0][m]);
}
// Runs solve() once; the test-count read stays disabled as in the original.
int main() {
    int T = 1;
    // scanf("%d", &T);
    for (; T > 0; --T) {
        solve();
    }
    return 0;
}
:::