CSP-S 2022 题解

· · 个人记录

T1 假期计划

首先可以用 bfs 在 O(nm) 时间内算出任意两点间的最短路 dis_{i,j},很明显 i,j 两点间可转车 k 次通达当且仅当 dis_{i,j} \le k+1,然后 O(n^2) 枚举景点B,C,那么景点A就是满足 dis_{1,A} \le k+1dis_{A,B} \le k+1 的权值最大的点(不和景点B相同),景点D同理,因为要求A,B,C,D两两不同,所以不能只维护权值最大的点,还要加上次大和次次大点,这个可以 O(n^2) 预处理,总体复杂度 O(nm+n^2)

我的实现多了个排序,复杂度 O(n^2 \log n)
代码:

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
const int MAXN = 2505;

inline ll read() {
    ll x = 0, f = 1; char c = getchar();
    while (c < '0' || c > '9') f = c == '-' ? -1 : 1, c = getchar();
    while (c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x * f;
}

ll s[MAXN], ans;
int n, m, k;
int dis[MAXN][MAXN];
bool vis[MAXN], igl[MAXN];

vector<int> G[MAXN], cdd[MAXN];
queue<int> Q;

int main() {
//  freopen("holiday.in", "r", stdin);
//  freopen("holiday.out", "w", stdout);
    n = read(), m = read(), k = read();
    for (int i = 2; i <= n; i++) s[i] = read();
    for (int i = 1; i <= m; i++) {
        int u = read(), v = read();
        G[u].push_back(v);
        G[v].push_back(u);
    }

    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n; j++) vis[j] = 0, dis[i][j] = 1e9; dis[i][i] = 0;
        Q.push(i);
        vis[i] = 1;
        while (!Q.empty()) {
            int top = Q.front();
            Q.pop();
            for (int to : G[top]) {
                if (vis[to]) continue;
                dis[i][to] = dis[i][top] + 1;
                vis[to] = 1;
                Q.push(to);
            }
        }
    }

    for (int i = 2; i <= n; i++) {
        for (int j = 2; j <= n; j++)
            if ((i != j) && dis[1][j] <= k + 1 && dis[i][j] <= k + 1)
                cdd[i].push_back(j);
        sort(cdd[i].begin(), cdd[i].end(), [](int a, int b) { return s[a] > s[b]; });
    }

    for (int i = 2; i <= n; i++)
        for (int j = i + 1; j <= n; j++) if (!cdd[i].empty() && !cdd[j].empty() && dis[i][j] <= k + 1) {
            auto p1 = cdd[i].begin(), p2 = cdd[j].begin();
            if (*p1 == j) p1++;
            if (*p2 == i) p2++;
            if (p1 == cdd[i].end() || p2 == cdd[j].end()) continue;

            auto pp = p1;
            while (pp != cdd[i].end() && (*pp == *p2 || *pp == j)) pp++;
            if (pp != cdd[i].end()) ans = max(ans, s[i] + s[j] + s[*pp] + s[*p2]);
            pp = p2;
            while (pp != cdd[j].end() && (*pp == *p1 || *pp == i)) pp++;
            if (pp != cdd[j].end()) ans = max(ans, s[i] + s[j] + s[*p1] + s[*pp]);
        }

    printf("%lld", ans);
    return 0;
}

T2 策略游戏

简单分类讨论,这里把 0 看作正数,第一个选下标的人叫 Alice,另一个叫 Bob。

首先考虑 l1=1,r1=n,l2=1,r2=m 时的情况。

B既有正数也有负数时,很明显因为 Bob 要最小化得分,所以无论 Alice 怎么选数, Bob 都会选一个与 Alice 选的数异号的且绝对值最大数,而 Alice 在保证选的数的符号不变的情况下,绝对值必然越小越好,所以答案即为:

\max(-(\min_{ A_{i \ge 0}} |A_i|)(\max_{A_i < 0} |A_i|), -(\min_{A_i < 0}|A_i|)(\max_{A_i \ge 0} |A_i|))

B只有负数时,若 A有负数,那么 Alice 选负数中绝对值最大数肯定最优,此时 Bob 必然要选绝对值最小的数;否则 Alice 只有正数可选,必然选择绝对值最小的,然后 Bob 选择绝对值最大的。答案即为:

\begin{cases} &(\max_{A_i < 0} |A_i|)(\min_{B_i < 0} |B_i|) &\exists i,A_i < 0\\ &-(\min_{A_i \ge 0} |A_i|)(\max_{B_i < 0} |B_i|) &\forall i, A_i \ge 0 \end{cases}

B只有正数时与只有负数时同理。

加上区间限制,我们只需要查询区间是否全部为正/负数以及区间正/负数的最大/小绝对值即可,用 ST 表或者线段树均可维护,复杂度 O(n \log n)

代码:

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
const int MAXN = 1e5+5;

inline ll read() {
    ll x = 0, f = 1; char c = getchar();
    while (c < '0' || c > '9') f = c == '-' ? -1 : 1, c = getchar();
    while (c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x * f;
}

int n, m, q;
int a[MAXN], b[MAXN];

struct info {
    int sgn;
    int pmx, pmn;//正数中最大/小的绝对值 
    int nmx, nmn;
    info() { sgn = 0, pmn = nmn = 1e9, pmx = nmx = -1; }
};

struct segTree {
    info T[MAXN << 2];

    #define ls (p << 1)
    #define rs (p << 1 | 1)

    inline info merge(info x, info y) {
        info ret;
        ret.sgn = x.sgn | y.sgn;
        ret.pmx = max(x.pmx, y.pmx);
        ret.pmn = min(x.pmn, y.pmn);
        ret.nmx = max(x.nmx, y.nmx);
        ret.nmn = min(x.nmn, y.nmn);
        return ret;
    }

    void build(int p, int l, int r, int *a) {
        if (l == r) {
            if (a[l] >= 0) T[p].sgn = 1, T[p].pmx = T[p].pmn = a[l];
            else T[p].sgn = 2, T[p].nmx = T[p].nmn = -a[l];
            return;
        }
        int mid = (l + r) >> 1;
        build(ls, l, mid, a);
        build(rs, mid + 1, r, a);
        T[p] = merge(T[ls], T[rs]);
    }

    info query(int p, int l, int r, int gl, int gr) {
        if (l >= gl && r <= gr) return T[p];
        int mid = (l + r) >> 1;
        info ret;
        if (mid >= gl) ret = query(ls, l, mid, gl, gr);
        if (mid < gr) ret = merge(ret, query(rs, mid + 1, r, gl, gr));
        return ret;
    }

    #undef ls
    #undef rs
} T[2];

int main() {
//  freopen("game.in", "r", stdin);
//  freopen("game.out", "w", stdout);
    n = read(), m = read(), q = read();
    for (int i = 1; i <= n; i++) a[i] = read();
    for (int i = 1; i <= m; i++) b[i] = read();

    T[0].build(1, 1, n, a);
    T[1].build(1, 1, m, b);

    while (q--) {
        int l1 = read(), r1 = read(), l2 = read(), r2 = read();
        info A = T[0].query(1, 1, n, l1, r1);
        info B = T[1].query(1, 1, m, l2, r2);

        if (B.sgn == 3) {
            ll ans = -1e18;
            if (A.sgn & 1) ans = max(ans, (ll) -A.pmn * B.nmx);
            if (A.sgn & 2) ans = max(ans, (ll) -A.nmn * B.pmx);
            printf("%lld\n", ans);
        } else if (B.sgn == 1) {
            if (A.sgn & 1) {
                printf("%lld\n", (ll) A.pmx * B.pmn);
            } else {
                printf("%lld\n", (ll) -A.nmn * B.pmx);
            }
        } else {
            if (A.sgn & 2) {
                printf("%lld\n", (ll) A.nmx * B.nmn);
            } else {
                printf("%lld\n", (ll) -A.pmn * B.nmx);
            }
        }
    }
    return 0;
}

T3 星战

一直在想根号分治。

首先我们发现可以实现反击其实是可以实现连续穿梭的必要条件,因为只要每个点都有出边,那么一定可以无限走下去(实际上就是个内向基环树森林)。考虑维护一个可重无序集合 S,若一个虫洞可用,那么它的出发节点就将在 S 中多出现一次,那么满足限制当且仅当 S = \{1,2,3,\dots,n\}。每次操作就相当于给定任意一个集合 T,将 T 加入 S 或者从 S 中删去一个 T,这种东西看起来就很不可以低于 O(n^2) 的时间精准维护(至少我没想到),于是可以考虑哈希,给每个节点规定一个随机权值,直接维护当前 S 中所有点的权值和即可,维护非常简单。

我代码的哈希方式是维护编号平方和,欢迎来叉XD。
代码:

#include <bits/stdc++.h>
#define lowbit(x) (x & -x)
#define pb push_back
#define mp make_pair
using namespace std;

typedef long long ll;
const int MAXN = 5e5+5;
const int Mod = 998244353;

#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 21], *p1 = buf, *p2 = buf;

inline int read() {
    int x = 0, f = 1; char c = getchar();
    while (c < '0' || c > '9') f = c == '-' ? -1 : 1, c = getchar();
    while (c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x * f;
}

int n, m, q;
ll stdHsh, curHsh, Hsh[MAXN][2], cursiz, siz[MAXN][2];

int main() {
    // freopen("galaxy.in", "r", stdin);
    // freopen("galaxy.out", "w", stdout);
    n = read(), cursiz = m = read();
    stdHsh = (ll) n * (n + 1) * (2 * n + 1) / 6;
    for (int i = 1; i <= m; i++) {
        int u = read(), v = read();
        curHsh += (ll) u * u;
        Hsh[v][1] += (ll) u * u;
        siz[v][1]++;
    }

    q = read();
    while (q--) {
        int op = read(), u = read(), v;
        if (op == 1) {
            v = read();

            cursiz--;
            siz[v][1]--, siz[v][0]++;

            curHsh -= (ll) u * u;
            Hsh[v][1] -= (ll) u * u;
            Hsh[v][0] += (ll) u * u;
        } else if (op == 2) {
            curHsh -= Hsh[u][1];
            Hsh[u][0] += Hsh[u][1];
            Hsh[u][1] = 0;

            cursiz -= siz[u][1];
            siz[u][0] += siz[u][1];
            siz[u][1] = 0;
        } else if (op == 3) {
            v = read();

            cursiz++;
            siz[v][0]--, siz[v][1]++;

            curHsh += (ll) u * u;
            Hsh[v][0] -= (ll) u * u;
            Hsh[v][1] += (ll) u * u;
        } else {
            curHsh += Hsh[u][0];
            Hsh[u][1] += Hsh[u][0];
            Hsh[u][0] = 0;

            cursiz += siz[u][0];
            siz[u][1] += siz[u][0];
            siz[u][0] = 0;
        }
        puts(cursiz == n && curHsh == stdHsh ? "YES" : "NO");
    }
    return 0;
}

T4 数据传输

这比去年T4简单到不知道哪里去了。

首先 k=1 时就是查询路径权值和,不再赘述。

k=2 时,可以发现此时我们一定是沿着从 uv 的路径走的,不可能走到路径之外的节点,如图:

很明显红色路径花了两步,能走到的最远的点也能从出发点一步走到,所以不可能走到路径之外。
将从 uv 的路径抽出来,设 f_i 表示到达第 i 个点所需的最小代价,转移方程非常好写。

k=3 时,这时我们发现它是可能走到路径外的:

蓝色路径花两步能走到的最远节点最短也需要两步,而红色路径不能,所以可能走到距离路径为 1 的节点。
将从 uv 的路径以及距离该路径为 1 的节点抽出来,我们发现 k=2 时的 dp 状态不适用,不过可以从中得到启发,重新设 f_{i,0/1/2} 表示现在考虑到了路径上第 i 个点,目前所在的点到第 i 个点的距离为 0/1/2,设 mn_{i} 为第 i 个节点的相邻节点中的最小权值,有转移方程:

\begin{cases} f_{i,0} = \min(f_{i-1,0},f_{i-1,1},f_{i-1,2})+val_{i}\\ f_{i,1} = \min(f_{i-1,0},f_{i-1,1}+mn_i)\\ f_{i,2} = f_{i-1,1} \end{cases}

现在单次询问可以做到 O(n),但还不够,注意到 f_i 可以看成是 f_{i-1} 某种意义下的线性组合,与矩阵乘法的形式相似,考虑新定义矩阵乘法:

(AB)_{i,j} = \min_k(A_{i,k}+B_{k,j})

一个非常美妙的性质是,新定义下的矩阵乘法也是满足结合律的:

\begin{aligned} \left [(AB)C \right ]_{i,j} &= \min_k((AB)_{i,k}+C_{k,j})\\ &= \min_k(\min_p(A_{i,p}+B_{p,k})+C_{k,j})\\ &= \min_k(\min_p(A_{i,p} +B_{p,k}+C_{k,j}))\\ &= \min_{k,p}(A_{i,p} +B_{p,k}+C_{k,j})\\ &= \min_p(\min_k(A_{i,p} +B_{p,k}+C_{k,j}))\\ &= \min_p(A_{i,p}+\min_k(B_{p,k}+C_{k,j}))\\ &= \min_p(A_{i,p}+(BC)_{p,j})\\ &= \left [A(BC) \right ]_{i,j} \end{aligned}

i 的转移矩阵即为:

\begin{bmatrix} val_i &0 &+\infty \\ val_i &mn_i &0 \\ val_i &+\infty &+\infty \end{bmatrix}

询问就是求一个路径积,倍增,树剖,点分治,LCT,全局平衡二叉树,TopTree,随便拿个东西就能维护。
代码(倍增):

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
const int MAXN = 2e5+5;
const ll INF = 1e18;

inline ll read() {
    ll x = 0, f = 1; char c = getchar();
    while (c < '0' || c > '9') f = c == '-' ? -1 : 1, c = getchar();
    while (c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x * f;
}

int n, q, K, val[MAXN];
int dep[MAXN], f[MAXN][18];

vector<int> T[MAXN];

struct Matrix {
    ll ele[3][3];
    ll &operator ()(int x, int y) { return ele[x][y]; }
    Matrix () { memset(ele, 0x3f, sizeof(ele)); }
} F[MAXN], up[MAXN][18], dwn[MAXN][18];

Matrix operator *(Matrix x, Matrix y) {
    Matrix ret;
    for (int i = 0; i < K; i++)
        for (int j = 0; j < K; j++)
            for (int k = 0; k < K; k++)
                ret(i, j) = min(ret(i, j), x(i, k) + y(k, j));
    return ret;
}

void dfs1(int x, int fa) {
    dep[x] = dep[fa] + 1;
    f[x][0] = fa;
    up[x][0] = F[fa];
    dwn[x][0] = F[x];
    for (int i = 1; i < 18; i++) {
        f[x][i] = f[f[x][i - 1]][i - 1];
        up[x][i] = up[x][i - 1] * up[f[x][i - 1]][i - 1];
        dwn[x][i] = dwn[f[x][i - 1]][i - 1] * dwn[x][i - 1];
    }

    for (int son : T[x]) {
        if (son == fa) continue;
        dfs1(son, x);
    }
}

inline int LCA(int u, int v) {
    if (dep[u] < dep[v]) swap(u, v);

    int h = dep[u] - dep[v];
    for (int i = 0; i < 18; i++)
        if (h >> i & 1)
            u = f[u][i];
    if (u == v) return u;

    for (int i = 17; ~i; i--)
        if (f[u][i] ^ f[v][i])
            u = f[u][i], v = f[v][i];

    return f[u][0];
}

int main() {
//  freopen("transmit.in", "r", stdin);
//  freopen("transmit.out", "w", stdout);
    n = read(), q = read(), K = read();
    for (int i = 1; i <= n; i++) F[i](0, 0) = F[i](1, 0) = F[i](2, 0) = val[i] = read(), F[i](0, 1) = F[i](1, 2) = 0;

    for (int i = 1; i < n; i++) {
        int u = read(), v = read();
        T[u].push_back(v);
        T[v].push_back(u);
    }

    if (K == 3)
    for (int i = 1; i <= n; i++)
        for (int to : T[i])
            F[i](1, 1) = min(F[i](1, 1), (ll) val[to]);

    dfs1(1, 0);

    while (q--) {
        int u = read(), v = read();
        Matrix tmp, tmp1;
        tmp(0, 0) = val[u];
        for (int i = 0; i < K; i++) tmp1(i, i) = 0;

        int lca = LCA(u, v), h = dep[u] - dep[lca];
        //printf("-%d %d %d-\n", u, v, lca);
        for (int i = 0; i < 18; i++)
            if (h >> i & 1)
                tmp = tmp * up[u][i], u = f[u][i];
        //tmp.print();

        h = dep[v] - dep[lca];
        for (int i = 0; i < 18; i++)
            if (h >> i & 1)
                tmp1 = dwn[v][i] * tmp1, v = f[v][i];

        printf("%lld\n", (tmp * tmp1)(0, 0));
    }
    return 0;
}