惰性求值、无穷流与发生的魔法(C++ 版)

· · 个人记录

#include <iostream>
#include <functional>
#include <memory>
#include <optional>
#include <utility>
#include <vector>
#include <numeric>
#include <cmath>

// =================================================================
// Part 1: 实现 Scheme 的惰性求值核心 (delay & force)
//
// Scheme: (define-syntax delay ...) (define (force ...))
// C++:    一个 Lazy<T> 类
// =================================================================

template<typename T>
class Lazy {
private:
    // 用于存储延迟计算的函数
    std::function<T()> computation;
    // 用于缓存计算结果,std::optional 天然适合表示“已计算”或“未计算”
    mutable std::optional<T> cache;

public:
    // 构造函数,接受一个 lambda 作为延迟的操作
    Lazy(std::function<T()> func) : computation(std::move(func)) {}

    // force 操作:获取值。如果是第一次,则计算并缓存
    T& operator()() const {
        if (!cache) {
            cache = computation();
        }
        return *cache;
    }
};

// 辅助函数,使得创建 Lazy 对象更像 Scheme 的 (delay ...)
template<typename F>
auto delay(F&& func) {
    // T 是 func() 的返回类型
    using T = decltype(func());
    return Lazy<T>(std::forward<F>(func));
}

// =================================================================
// Part 2: 实现惰性流 (Stream)
//
// Scheme: (stream-cons a b), (stream-car a), (stream-cdr a)
// C++:    一个 Stream<T> 结构体和几个辅助函数
//
// C++ 中需要用指针(这里是智能指针)来处理这种递归数据结构
// =================================================================

template<typename T>
struct StreamNode; // 前向声明

// Stream 本身其实是一个指向节点的智能指针,便于管理内存
template<typename T>
using Stream = std::shared_ptr<StreamNode<T>>;

template<typename T>
struct StreamNode {
    T head;
    // tail 是一个“延迟计算的、指向下一个节点的流”
    Lazy<Stream<T>> tail;
};

// (stream-cons a b)
template<typename T, typename F>
Stream<T> stream_cons(T head, F&& tail_func) {
    return std::make_shared<StreamNode<T>>(StreamNode<T>{
            head,
            delay([f = std::forward<F>(tail_func)]() { return f(); })
    });
}

// (stream-car x)
template<typename T>
T stream_car(const Stream<T>& s) {
    return s->head;
}

// (stream-cdr x)
template<typename T>
Stream<T> stream_cdr(const Stream<T>& s) {
    return (s->tail)(); // 调用 Lazy<T>::operator()() 来 force
}

// =================================================================
// Part 3: 创建和操作流
//
// Scheme: (int-from n), (get a pos), (take a len)
// =================================================================

// (int-from n) -> 创建一个无限自然数流
Stream<long long> int_from(long long n) {
    return stream_cons<long long>(n, [n]() {
        return int_from(n + 1);
    });
}

// (get a pos) -> 获取流中指定位置的元素
template<typename T>
T get(Stream<T> s, int pos) {
    if (pos == 0) {
        return stream_car(s);
    }
    return get(stream_cdr(s), pos - 1);
}

// (take a len) -> 从流中取出前 len 个元素,放入一个 vector
template<typename T>
std::vector<T> take(Stream<T> s, int len) {
    if (len == 0 || !s) {
        return {};
    }
    std::vector<T> result;
    result.push_back(stream_car(s));
    auto rest = take(stream_cdr(s), len - 1);
    result.insert(result.end(), rest.begin(), rest.end());
    return result;
}

// 辅助函数,用于打印 vector
template<typename T>
void print_vector(const std::vector<T>& vec) {
    std::cout << "[ ";
    for (const auto& item : vec) {
        std::cout << item << " ";
    }
    std::cout << "]" << std::endl;
}

// =================================================================
// Part 4: 高阶函数和函数组合
//
// Scheme: (stream-map operator a), (combine f g), (tunnel list-of-params)
// C++:    模板函数, std::function, 和折叠表达式 (或递归)
// =================================================================

// (stream-map operator a)
template<typename Func, typename T>
Stream<typename std::invoke_result<Func, T>::type>
stream_map(Func&& op, Stream<T> s) {
    using ResultType = typename std::invoke_result<Func, T>::type;
    if (!s) {
        return nullptr;
    }
    return stream_cons<ResultType>(
            op(stream_car(s)),
            [op = std::forward<Func>(op), s]() {
                return stream_map(op, stream_cdr(s));
            }
    );
}

// (combine f g) -> 返回 g(f(x))
// 我们需要 std::function 来抹除具体 lambda 类型
template <typename T>
using Func = std::function<T(T)>;

Func<double> combine(const Func<double>& f, const Func<double>& g) {
    return [f, g](double x) { return g(f(x)); };
}

// (tunnel list-of-params) -> 将一串函数组合起来
// C++17 折叠表达式可以优雅地实现,但为了更贴近 Scheme 的递归,我们用递归
Func<double> tunnel(const std::vector<Func<double>>& funcs) {
    if (funcs.empty()) {
        // 返回一个恒等函数
        return [](double x) { return x; };
    }
    if (funcs.size() == 1) {
        return funcs[0];
    }

    // 递归地组合
    // (combine f (combine g h)) ...
    Func<double> combined_rest = tunnel(std::vector<Func<double>>(funcs.begin() + 1, funcs.end()));
    return combine(funcs[0], combined_rest);
}

// =================================================================
// Part 5: 最终的例子
//
// Scheme: (define example (stream-map (tunnel ...) nature-numbers))
// =================================================================

int main() {
    // (define nature-numbers (int-from 1))
    auto nature_numbers = int_from(1);

    std::cout << "Taking first 10 natural numbers:" << std::endl;
    print_vector(take(nature_numbers, 10)); // 输出: [ 1 2 3 4 5 6 7 8 9 10 ]
    std::cout << std::endl;

    // (tunnel (lambda (x) (* x x)) (lambda (x) (+ x 3)) (lambda (x) (log x 2)))
    // 组合 f(x) = log2(x^2 + 3)
    std::vector<Func<double>> func_list = {
            [](double x) { return x * x; },        // x -> x^2
            [](double x) { return x + 3; },        // y -> y + 3
            [](double x) { return std::log2(x); }  // z -> log2(z)
    };

    auto composed_func = tunnel(func_list);

    // (define example (stream-map ... nature-numbers))
    // 注意:需要先将 long long 流转为 double 流
    auto nature_numbers_double = stream_map([](long long x){ return static_cast<double>(x); }, nature_numbers);
    auto example_stream = stream_map(composed_func, nature_numbers_double);

    std::cout << "Taking first 5 elements from the 'example' stream:" << std::endl;
    // (take example 5)
    print_vector(take(example_stream, 5));

    // 计算结果验证:
    // x=1: log2(1*1 + 3) = log2(4) = 2
    // x=2: log2(2*2 + 3) = log2(7) approx 2.807
    // x=3: log2(3*3 + 3) = log2(12) approx 3.585
    // x=4: log2(4*4 + 3) = log2(19) approx 4.248
    // x=5: log2(5*5 + 3) = log2(28) approx 4.807
    // 输出: [ 2 2.80735 3.58496 4.24793 4.80735 ]

    return 0;
}