P1763 埃及分数

· · 题解

题意

给定一个分数 \frac{a}{b},并且设 \frac{a}{b} = \frac{1}{p_1} + \frac{1}{p_2} + \cdots + \frac{1}{p_n}(\frac{1}{p_1} > \frac{1}{p_2} > \cdots > \frac{1}{p_n}),求出在满足 n 尽量小的同时 \frac{1}{p_n} 尽量大,即 p_n 尽量小。

分析

第一眼想到的应该是广搜,但是明显广搜空间会爆,而普通的深搜时间也无法保证,因为 1 < a < b < 1000,所以我们有一种新的方法:迭代加深。

迭代加深其实是对深搜的一种优化。当一道题明显只能用搜索,但是普通深搜和广搜都无法通过时我们就可以考虑迭代加深。

众所周知,深搜是一个路径一直搜到最后才停下,但是假设有这样的搜索树:

那么明显从根节点开始遍历左子树时会遍历很远,但是目标点却很近,所以我们可以在每次搜索时定义一个最大搜索长度,超过长度就停止搜索。

那么对于这道题还有一个问题就是枚举的上界和下界,这里是解题的关键。

假设目前枚举的是第 dep 个分数的分母,而之前所有分数的和是 \frac{d}{c},那么我们需要求出第 dep 个分数的分母的上下界,而分子永远是 1

我们设第 dep 个分数的分母为 p,显然可得 \frac{d}{c} + \frac{1}{p} \le \frac{a}{b},接着得 \frac{1}{p} \le \frac{a}{b} - \frac{d}{c},设 \frac{a}{b} - \frac{d}{c} = \frac{g}{e},那么 \frac{1}{p} \le \frac{g}{e},显然 p \times g \ge e,得 p \ge e \div g

我们同样设第 dep 个分数的分母为 p,总共分数个数 n 个,因为分母都是递增的,所以从 dep + 1n 得每个分母都 \ge p,虽然不会取等于,但是考虑 = 更加好算。那么得 \frac{d}{c} + \frac{1}{p} \times (n - dep + 1) \ge \frac{a}{b},在循环中每次判断是否成立即可。因为假设 p = x 已经不符合要求,那么 p > x 显然更加不符合要求。

代码:

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <utility>
#include <cmath>
#include <cstring>
using namespace std;

#define ll unsigned long long

static char buf[1000000], * p1 = buf, * p2 = buf;
#define getchar() p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1000000, stdin), p1 == p2) ? EOF : *p1++

inline ll read()
{
    register char ch = getchar();
    register ll x = 0;
    while (ch < '0' || ch > '9') ch = getchar();
    while (ch >= '0' && ch <= '9')
    {
        x = (x << 1) + (x << 3) + (ch ^ 48);
        ch = getchar();
    }
    return x;
}

inline void write(ll x)
{
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}

inline int gcd(register int a, register int b) // 计算最大公因数
{
    return (b == 0 ? a : gcd(b, a % b));
}

#define lcm(a, b) (a * b / gcd(a, b)) // 最小公倍数

inline pair<ll, ll> add(const pair<ll, ll>& a, const pair<ll, ll>& b) // a.first 表示第一个分数的分子,a.second 表示第一个分数的分母,b.first 表示第二个分数的分子,b.second 表示第二个分数的分母
{
    register ll newfm = lcm(a.second, b.second);
    register ll newfz = newfm / a.second * a.first + newfm / b.second * b.first;
    register ll g = gcd(newfm, newfz);
    newfm /= g;
    newfz /= g;
    return make_pair(newfz, newfm);
}

inline pair<ll, ll> Minus(register const pair<ll, ll>& a, register const pair<ll, ll>& b)
{
    register ll newfm = lcm(a.second, b.second);
    register ll newfz = newfm / a.second * a.first - newfm / b.second * b.first;
    register ll g = gcd(newfm, newfz);
    newfm /= g;
    newfz /= g;
    return make_pair(newfz, newfm);
}

inline bool bigger(register const pair<ll, ll>& a, register const pair<ll, ll>& b) // 返回第一个分数 >= 第二个分数
{
    register ll newfm = lcm(a.second, b.second);
    register ll nfz1 = newfm / a.second * a.first;
    register ll nfz2 = newfm / b.second * b.first;
    return nfz1 >= nfz2;
}

inline bool big(register const pair<ll, ll>& a, register const pair<ll, ll>& b) // 返回第一个分数 >= 第二个分数
{
    register ll newfm = lcm(a.second, b.second);
    register ll nfz1 = newfm / a.second * a.first;
    register ll nfz2 = newfm / b.second * b.first;
    return nfz1 > nfz2;
}

const int N = 1e5 + 5;

ll a, b, tdep;
pair<ll, ll> k;
ll arr[N], res[N];

inline void dfs(register ll dep, register pair<ll, ll> p)
{
    if (dep == tdep)
    {
        register pair<ll, ll> q = Minus(k, p);
        if (q.first != 1) return;
        arr[dep] = q.second;
        if (!res[tdep] || res[tdep] > arr[tdep])
        {
            memcpy(res, arr, sizeof(res));
        }
        return;
    }
    register pair<ll, ll> t = Minus(make_pair(a, b), make_pair(p.first, p.second));
    register ll minn = max((ll)arr[dep - 1] + 1, (ll)ceil(t.second * 1.0 / t.first));
    register int l = tdep - dep + 1;
    for (register int i = minn; big(add(p, make_pair(l, i)), k); ++i)
    {
        arr[dep] = i;
        dfs(dep + 1, add(p, make_pair(1, i)));
    }
}

int main()
{
    a = read(), b = read();
    register int tmp = gcd(a, b);
    a /= tmp;
    b /= tmp;
    if (a == 1)
    {
        write(b);
        return 0;
    }
    k.first = a;
    k.second = b;
    register int i, j;
    for (i = 2; ; ++i)
    {
        memset(res, 0, sizeof(res));
        tdep = i;
        dfs(1, make_pair(0, 1));
        if (res[i])
        {
            for (j = 1; j <= i; ++j)
            {
                write(res[j]);
                putchar(' ');
            }
            return 0;
        }
    }
    return 0;
}

最终用时:

C++14 (GCC 9) O2: 941ms

C++14 (GCC 9): 1.12s