P1054 [NOIP2005 提高组] 等价表达式 题解

· · 题解

看了一圈题解区只有我用了暴力展开 /jy

实际上并不难,还有点水,也不难调。

相信各位应当都会中缀表达式的计算:我们维护一个符号栈和一个数据栈,然后分类讨论根据优先级计算即可。这一点不管是别的题目还是题解区已经讲的很好了,OI Wiki 上的讲解也不错。

这篇题解的重点是把普通整数上的操作扩展到多项式。

由于只有一个字母,我们可以维护系数向量 v,其中 v_k 代表 a^k 的系数,常数就是 a^0 的系数。于是,我们就可以用 O(k) 时间做到加减操作。

然后考虑乘和乘方。先考虑朴素的 O(n^2) 乘法。(通过手工模拟可得),代码为:

for (int i = 0; i < K; i++)
    for (int j = 0; i + j < K; j++) 
        r.k[i + j] += k[i] * x.k[j];

如果要求更高的速度,可以用 FFT 优化。但是本题数据范围比较小,所以不用。

乘方可以用经典的快速幂,复杂度是 O(k^2\log n) 的,或者是 O(k \log^2 n)(FFT 优化)。

接下来就是代码时间了,其实思路是非常好想的。

但是有几点注意:

接下来放看起来还是很不错的代码:

#include <iostream>
#include <stack>
#include <cstring>
#define siz(x) static_cast<int> ((x).size ())
#define fi first
#define se second
#define int long long
using namespace std;
const int N = 30, K = 500, P = 998244353;

class poly {
public:
    int k[K];

    poly () { memset (k, 0, sizeof k); }
    poly (const poly &x) { memcpy (k, x.k, sizeof k); }
    poly &operator = (const poly &x) { memcpy (k, x.k, sizeof k); return *this; }
    poly (int x, bool flag = false) { memset (k, 0, sizeof k), k[flag] = x; }
    poly operator + (const poly &x) {
        poly r;
        for (int i = 0; i < K; i++) r.k[i] = (k[i] + x.k[i]) % P;
        return r;
    }

    poly operator - (const poly &x) {
        poly r;
        for (int i = 0; i < K; i++) r.k[i] = (k[i] + P - x.k[i]) % P;
        return r;
    }

    poly operator * (const poly &x) {
        poly r;
        for (int i = 0; i < K; i++) for (int j = 0; i + j < K; j++) (r.k[i + j] += k[i] * x.k[j] % P) %= P;
        return r;
    }

    poly operator ^ (int x) {
        poly r = *this, p = *this;
        --x;
        while (x) {
            if (x & 1) r = r * p;
            x >>= 1, p = p * p;
        }
        return r;
    }

    bool operator == (const poly &x) {
        for (int i = 0; i < K; i++) if (k[i] != x.k[i]) return false;
        return true;
    }

    friend ostream &operator << (ostream &out, const poly &x) {
        for (int i = K - 1; i >= 2; i--) if (x.k[i]) {
            if (x.k[i] == 1) out << "a^" << i << " + ";
            else if (x.k[i] == -1) out << "-a^" << i << " + ";
            else out << x.k[i] << "a^" << i << " + ";
        }
        if (x.k[1] == 1) out << "a + ";
        else if (x.k[1] == -1) out << "-a + ";
        else if (x.k[1]) out << x.k[1] << "a + ";
        out << x.k[0];
        return out;
    }
};

stack<char> op;
stack<poly> q;

void ins (char c)
{
    poly y, x;

    if (c == '(') return;
    if (q.empty ()) cout << "fuck off", exit (0);
    y = q.top (), q.pop ();
    if (q.empty ()) cout << "fuck off", exit (0);
    x = q.top (), q.pop ();
    switch (c) {
    case '+': q.push (x + y); break;
    case '-': q.push (x - y); break;
    case '*': q.push (x * y); break;
    case '^': q.push (x ^ y.k[0]); break;
    default: break;
    }
}

int prio (char c)
{
    switch (c) {
    case '(': return 0;
    case '+': case '-': return 1;
    case '*': return 2;
    case '^': return 3;
    default: return -1;
    }
}

bool space (char c) { return c == ' ' || c == '\n' || c == '\r'; }

poly expand (void)
{
    int level = 0;
    char c;

    while (!q.empty ()) q.pop ();
    while (space (c = getchar ()));
    do {
        if (isdigit (c)) {
            int x = 0;
            do x = (x * 10 % P + c - '0') % P; while (isdigit (c = getchar ()));
            while (space (c) && c != '\n') c = getchar ();
            q.push (poly (x));
            if (c == '\n') break;
            else continue;
        }

        switch (c) {
        case 'a': q.push (poly (1, true)); break;
        case '(': op.push ('('), level++; break;
        case ')':
            if (--level < 0) break;
            while (!op.empty () && op.top () != '(') ins (op.top ()), op.pop ();
            op.pop (); break;
        default:
            while (!op.empty () && prio (op.top ()) >= prio (c)) ins (op.top ()), op.pop ();
            op.push (c); break;
        }
        while (space (c = getchar ()) && c != '\n');
    } while (c != '\n');

    while (!op.empty ()) ins (op.top ()), op.pop ();
    return q.top ();
}

int read (void)
{
    int res = 0;
    char c;
    while (!isdigit (c = getchar ()));
    do res = res * 10 + c - '0'; while (isdigit (c = getchar ()));
    return res;
}

signed main (void)
{
    int n;
    poly res;

    res = expand ();
    n = read ();

    for (int i = 1; i <= n; i++) if (expand () == res) putchar ("ABCDEFGHIJKLMNOPQRSTUVWXYZ"[i - 1]);
    putchar ('\n');
    return 0;
}