22.11.17

· · 个人记录

A

题意:

问你有多少个长度为 n 的序列,满足每一个数都小于等于 m ,且存在一个和为 p 的子段,紧接着一个和为 q 的子段,再紧接着一个和为 z 的子段。答案对 998244353 取模。

1\leq n\leq50 1\leq m\leq10^9 \max(p,q,r)\leq6

题解:

状压 dp 解决,考虑先把原题容斥掉,可以变为问有多少个序列不满足条件,最终的答案就是总方案数-这个值。总方案数很好计算就是 m^n ,问题就是怎么计算有多少个不满足条件的序列。我们仔细观察会发现一个很奇怪的限制, p,q,r 都非常小,甚至考试的时候专门把他们的 \max\leq7 变为 \max\leq6 ,我们把这里作为我们的突破口,我们一位一位的来构造,因为我们想要没有满足条件的序列,那么我们就是对于任何时刻,都不会同时存在三个 r,q+r,p+q+r 的后缀和。那么我们的 dp 状态就很好定义了, dp_{i,mask} 表示的是前i个位置,后缀和的集合为 mask,因为 p,q,r 都非常的小,总共的值才 18 所有我们只用维护 2^{18}\times n 种状态就好了。转移的收,我们使用二进制下的左移就可以了。

复杂度分析:

时间复杂度:O(n\times 2^{p+q+r}\times \max(p,q,r))

空间复杂度:O(n\times 2^{p+q+r})

代码:

#include <bits/stdc++.h>
#define int long long
using namespace std;
int n, m, p, q, r, dp[55][1001001], maxv, sum, pw = 1;
const int mod = 998244353;
signed main() {
    freopen("a.in", "r", stdin);
    freopen("a.out", "w", stdout);
    scanf("%lld%lld%lld%lld%lld", &n, &m, &p, &q, &r);
    maxv = max(p, max(q, r));
    dp[0][0] = 1;
    for (int i = 1; i <= n; i++) pw = pw * m % mod;
    for (int i = 1; i <= n; i++)
        for (int j = 0; j < (1ll << (p + q + r)); j++)
            if (dp[i - 1][j]) {
                if (m > maxv)
                    (dp[i][0] += dp[i - 1][j] * (m - maxv) % mod) %= mod;
                for (int k = 1; k <= min(maxv, m); k++) {
                    int num = (j << k) & ((1ll << (p + q + r)) - 1) | (1ll << (k - 1));
                    if (((num >> (r - 1)) & 1) && ((num >> (q + r - 1)) & 1) &&
                        ((num >> (p + q + r - 1)) & 1))
                        continue;
                    (dp[i][num] += dp[i - 1][j]) %= mod;
                }
            }
    for (int j = 0; j < (1ll << (p + q + r)); j++) (sum += dp[n][j]) %= mod;
    cout << (pw - sum + mod) % mod << endl;
    return 0;
}

B

题意:

求出 (b-1)\times b^n \bmod m

1\leq n\leq10^{10^6} 2\leq b\leq10^{10^6} 1\leq m\leq10^9

题解:

我们先把 b-1b \bmod m 的值全部算出来,因为 m\tt int 范围的,而 b 不是,所以需要用到高精度,但是高精度对 \tt int 就好了,并不用高精度对高精度,所需的时间复杂度就是 O(lgb)的。之后考虑优化我们的指数 n ,因为普通的快速幂是 \log_2 的,很明显过不去,所以考虑优化,这里可以使用拓展欧拉定理,但是需要特判 n 是不是大于等于 \phi(m) 的。

复杂度分析:

时间复杂度:O(lgb)

空间复杂度:O(lgb)

代码:

#include <bits/stdc++.h>
#define int long long
using namespace std;
int n[1001001], b[1001001], mod, c[1001001], t[1001001];
char B[1001001];
void mi(int *a, int *b) {
    for (int i = 0; i <= 1e6 + 10; i++) c[i] = 0;
    c[0] = max(a[0], b[0]);
    for (int i = 1; i <= a[0]; i++) {
        c[i] += (a[i] - b[i]);
        if (c[i] < 0) {
            c[i + 1]--;
            c[i] += 10;
        }
    }
    while (c[c[0]] == 0 && c[0] > 1) c[0]--;
}
int mo(int *a, int b) {
    int tmp = 0;
    for (int i = a[0]; i; i--) tmp = (tmp * 10 % b + a[i]) % b;
    return tmp;
}
int div(int *a, int b) {
    int tmp = 0, ans = 0;
    for (int i = a[0]; i; i--) {
        tmp = tmp * 10 + a[i];
        ans = ans * 10 + tmp / b;
        tmp %= b;
    }
    return ans;
}
int get(int x) {
    int ans = x;
    for (int i = 2; i * i <= x; i++)
        if (x % i == 0) {
            ans -= ans / i;
            while (x % i == 0) x /= i;
        }
    if (x > 1)
        ans -= ans / x;
    return ans;
}
int pow_mod(int a, int b) {
    int ans = 1;
    while (b) {
        if (b & 1)
            ans = ans * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ans;
}
signed main() {
    freopen("b.in", "r", stdin);
    freopen("b.out", "w", stdout);
    char ch = getchar();
    while (ch < '0' || ch > '9') ch = getchar();
    while (ch >= '0' && ch <= '9') {
        B[n[0]++] = ch;
        ch = getchar();
    }
    for (int i = 1; i <= n[0]; i++) n[i] = (B[n[0] - i] - '0');
    ch = getchar();
    while (ch < '0' || ch > '9') ch = getchar();
    while (ch >= '0' && ch <= '9') {
        B[b[0]++] = ch;
        ch = getchar();
    }
    for (int i = 1; i <= b[0]; i++) b[i] = (B[b[0] - i] - '0');
    cin >> mod;
    t[0] = 1;
    t[1] = 1;
    mi(b, t);
    int t1 = mo(c, mod);
    int t2 = mo(b, mod);
    mi(n, t);
    int t3 = mo(c, get(mod));
    if (div(c, get(mod)) > 0)
        t3 += get(mod);
    cout << t1 * pow_mod(t2, t3) % mod;
    return 0;
}

C

题意:

给你个带权无向图,你可以任意交换两条边的边权,问你进行任意次操作之后,从 uv ,再到 w 的最小边权之和是多少。

1\leq n,m\leq2\times 10^5

题解:

这题可以变化为从 v 开始,到 uw 的经过所有边的最小边权和。因为我们可以任意交换边权,所以我们一定是从最小的开始用。而我们从 v 开始,一定是从某一个节点 m 开始分叉的,所以我们就去枚举这个分叉点 m 找到一个点,使得它到给出的三个点经过的边数最小即可,然后选取最小的几个即可。

复杂度分析:

时间复杂度:O(n)

空间复杂度:O(n)

代码:

#include <bits/stdc++.h>
#define int long long
using namespace std;
int n, m, u, v, x, y, z, e[1001001], dis[3][1001001], s[1001001];
queue<int> q;
vector<int> g[1001001];
void bfs(int id, int u) {
    q.push(u);
    dis[id][u] = 0;
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int v : g[u])
            if (dis[id][v] == -1) {
                dis[id][v] = dis[id][u] + 1;
                q.push(v);
            }
    }
}
signed main() {
    freopen("c.in", "r", stdin);
    freopen("c.out", "w", stdout);
    scanf("%lld%lld%lld%lld%lld", &n, &m, &x, &y, &z);
    for (int i = 1; i <= m; i++) {
        scanf("%lld%lld%lld", &u, &v, &e[i]);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    memset(dis, -1, sizeof(dis));
    bfs(0, x);
    bfs(1, y);
    bfs(2, z);
    sort(e + 1, e + m + 1);
    for (int i = 1; i <= m; i++) s[i] = s[i - 1] + e[i];
    int ans = 1ll << 50;
    for (int i = 1; i <= n; i++)
        if (dis[0][i] != -1 && dis[1][i] != -1 && dis[2][i] != -1 && dis[0][i] + dis[1][i] + dis[2][i] <= m)
            ans = min(ans, s[dis[0][i] + dis[1][i] + dis[2][i]] + s[dis[1][i]]);
    cout << ans;
    return 0;
}

D

题意:

给你 n 条水平或者垂直的线段,问你最少需要台笔多少次,才能将所有的线段全部遍历 m

1\leq n\leq2\times 10^3

题解:

小学的时候就学过一笔画问题,这题很类似。我们还是从简单的问题开始考虑, m=1。不同连通块肯定是要用一次抬笔的,对于一个连通块内,计算一下有多少个交点(或者顶点)有奇数个度,我们一次抬笔只能解决两个,所以我们统计一下有多少个,然后计算贡献就好了。对于 m\ne1 的情况,因为每个度都要便利 m 次,所以把 \text{每个点的度}\times m 就好了。

复杂度分析:

时间复杂度:O(n^2)

空间复杂度:O(n^2)

代码:

#include <bits/stdc++.h>
#define int long long
using namespace std;
int n, m, w, tot, hs[100100], fa[100100], c[100100][5], mc;
struct node {
    int xa, ya, xb, yb;
} s[100100], rs[100100];
vector<pair<int, int>> sx[100100], sy[100100];
int fd(int x) { return x == fa[x] ? x : fa[x] = fd(fa[x]); }
void mer(int x, int y) {
    x = fd(x), y = fd(y);
    if (x != y) {
        mc++;
        fa[x] = y;
        for (int i = 1; i <= 4; i++) c[y][i] += c[x][i];
    }
}
signed main() {
    freopen("d.in", "r", stdin);
    freopen("d.out", "w", stdout);
    scanf("%lld%lld", &n, &m);
    for (int i = 1; i <= n; i++) {
        scanf("%lld%lld%lld%lld", &s[i].xa, &s[i].ya, &s[i].xb, &s[i].yb);
        if (s[i].xa > s[i].xb)
            swap(s[i].xa, s[i].xb);
        if (s[i].ya > s[i].yb)
            swap(s[i].ya, s[i].yb);
        hs[++w] = s[i].xa;
        hs[++w] = s[i].ya;
        hs[++w] = s[i].xb;
        hs[++w] = s[i].yb;
    }
    sort(hs + 1, hs + w + 1);
    w = unique(hs + 1, hs + w + 1) - hs - 1;
    for (int i = 1; i <= n; i++) {
        s[i].xa = lower_bound(hs + 1, hs + w + 1, s[i].xa) - hs;
        s[i].xb = lower_bound(hs + 1, hs + w + 1, s[i].xb) - hs;
        s[i].ya = lower_bound(hs + 1, hs + w + 1, s[i].ya) - hs;
        s[i].yb = lower_bound(hs + 1, hs + w + 1, s[i].yb) - hs;
        if (s[i].xa == s[i].xb)
            sy[s[i].xa].push_back({ s[i].ya, s[i].yb });
        else
            sx[s[i].ya].push_back({ s[i].xa, s[i].xb });
    }
    for (int i = 1; i <= w; i++)
        if (!sx[i].empty()) {
            sort(sx[i].begin(), sx[i].end());
            int lm = 0, rm = 0;
            for (pair<int, int> x : sx[i]) {
                if (x.first > rm) {
                    if (lm)
                        rs[++tot] = node{ lm, i, rm, i };
                    lm = x.first;
                    rm = x.second;
                } else
                    rm = max(rm, x.second);
            }
            if (lm)
                rs[++tot] = node{ lm, i, rm, i };
        }
    for (int i = 1; i <= w; i++)
        if (!sy[i].empty()) {
            sort(sy[i].begin(), sy[i].end());
            int lm = 0, rm = 0;
            for (pair<int, int> x : sy[i]) {
                if (x.first > rm) {
                    if (lm)
                        rs[++tot] = node{ i, lm, i, rm };
                    lm = x.first;
                    rm = x.second;
                } else
                    rm = max(rm, x.second);
            }
            if (lm)
                rs[++tot] = node{ i, lm, i, rm };
        }
    for (int i = 1; i <= tot; i++) {
        fa[i] = i;
        c[i][1] = 2;
    }
    for (int i = 1; i <= tot; i++)
        if (rs[i].ya == rs[i].yb) {
            for (int j = 1; j <= tot; j++)
                if (rs[j].xa == rs[j].xb) {
                    int fy = fd(j);
                    if ((rs[i].xa == rs[j].xa || rs[i].xb == rs[j].xa) &&
                        (rs[i].ya == rs[j].ya || rs[i].ya == rs[j].yb)) {
                        mer(i, j);
                        c[fy][1] -= 2;
                        c[fy][2]++;
                    } else if (((rs[i].xa == rs[j].xa || rs[i].xb == rs[j].xa) && rs[i].ya >= rs[j].ya &&
                                rs[i].ya <= rs[j].yb) ||
                               ((rs[j].ya == rs[i].ya || rs[j].yb == rs[i].ya) && rs[j].xa >= rs[i].xa &&
                                rs[j].xa <= rs[i].xb)) {
                        mer(i, j);
                        c[fy][1]--;
                        c[fy][3]++;
                    } else if (rs[j].xa >= rs[i].xa && rs[j].xa <= rs[i].xb && rs[i].ya >= rs[j].ya &&
                               rs[i].ya <= rs[j].yb) {
                        mer(i, j);
                        c[fy][4]++;
                    }
                }
        }
    int ans = 0, cc = 0;
    for (int i = 1; i <= tot; i++)
        if (fa[i] == i)
            ans += max(1ll, m * (c[i][1] + c[i][2] * 2 + c[i][3] * 3 + c[i][4] * 4) / 2 - m / 2 * c[i][1] -
                                m * c[i][2] - m * 3 / 2 * c[i][3] - m * 2 * c[i][4]);
    cout << ans - 1;
    return 0;
}