全面超越标准库 bitset 手把手教程

· · 科技·工程

省流:实现了速度约是 std::bitset 两倍的向量化 bitset,代码在参考实现部分。

说在前面

在算法竞赛中,bitset 一直是常数优化的利器。无论是图论中的传递闭包,状态压缩动态规划中的可行性转移,还是字符串匹配中的 Shift-And 算法,使用 bitset 都能通过将 w=64 个状态进行压缩,达到理论上 \frac 1w 的优秀优化效果。同时近年来考察 bitset 使用的题目逐渐增多,两年省选就各有一道(即 [省选联考 2025] 追忆、[省选联考 2026] 夜空),这使得 bitset 的地位愈发重要。

大部分人追求代码简洁,选择了标准库提供的 std::bitset。然而,随着现代 CPU 架构的演进和竞赛题目时限的日益内卷,标准库的 std::bitset 开始暴露出两个令人头疼的性能瓶颈:

  1. SIMD 向量化利用率低下:现代 x86 CPU 普遍支持 AVX2 指令集,能够在一个时钟周期内处理 256 位的数据。但受限于编译器的自动向量化能力和标准库的保守实现,std::bitset 的底层往往只能依赖标量指令(逐个 64 位处理,甚至部分实现是 32 位的),白白浪费了 CPU 庞大的并行吞吐能力。

  2. 嵌套运算的临时对象灾难:在处理如 A = (B & C) ^ (~D | E) 这样的复杂逻辑表达式时,C++ 的运算符重载机制会为每一步二元运算生成一个完整的临时对象,并伴随着多次对内存的完整遍历。这不仅导致了极高的内存带宽开销,还严重破坏了 CPU 的缓存命中率。

为了打破这些桎梏,获得更高的效率,我们不能再依赖标准库。本文将带你从零开始,手写一个专为算法竞赛打造的极速静态 bitset。我们将直接介入底层,利用 AVX2 指令集实现真正的 256 位宽并行计算。同时,我们将引入表达式模板这一技术,通过编译期构建抽象语法树实现延迟计算,将复杂的复合运算折叠成一次完美的单趟内存遍历。

接下来是教程。

记号约定

为了方便描述,我们定义:

基础布局

要实现一个极致性能的 bitset,第一步是设计其内存布局。与 std::bitset 类似,我们使用 uint64_t(字)数组来存储位信息。但为了充分利用 AVX2 指令集,我们需要引入 256 位的 SIMD 向量类型 __m256i(块)。

template <size_t kSize>
class Bitset {
protected:
    typedef uint64_t Word;
    typedef __m256i Block;

    static const size_t kWordSize = sizeof(Word) * CHAR_BIT;
    static const size_t kBlockSize = sizeof(Block) / sizeof(Word);
    static const size_t kWordCount = (kSize + kWordSize - 1) / kWordSize;
    static const size_t kBlockCount = (kWordCount + kBlockSize - 1) / kBlockSize;

    alignas(Block) Word mData[kBlockCount * kBlockSize]{};

    void trim() {
        if (kSize % kWordSize) mData[kWordCount - 1] &= (1ULL << (kSize % kWordSize)) - 1;
        if (kBlockCount * kBlockSize > kWordCount) memset(mData + kWordCount, 0, (kBlockCount * kBlockSize - kWordCount) * sizeof(Word));
    }

    void normalize() {
        if (kSize % kWordSize) mData[kWordCount - 1] &= (1ULL << (kSize % kWordSize)) - 1;
        if (kBlockCount * kBlockSize > kWordCount) memset(mData + kWordCount, 0, (kBlockCount * kBlockSize - kWordCount) * sizeof(Word));
    }

    inline Word wordAt(size_t position) const {
        return mData[position];
    }

    inline Block blockAt(size_t position) const {
        return _mm256_load_si256((const Block*)&mData[position * kBlockSize]);
    }

    inline void wordSet(size_t position, Word value) {
        mData[position] = value;
    }

    inline void blockSet(size_t position, Block value) {
        _mm256_store_si256((Block*)&mData[position * kBlockSize], value);
    }
};

这里有几个关键的设计细节:

  1. 基本单位:Word64 位的标量单位,而 Block256 位的向量单位。我们预先计算好所需的 Word 数量和 Block 数量。注意,数组的实际分配大小是 kBlockCount * kBlockSize,这是在末尾补满一个块,保证我们在进行 SIMD 批量操作时,永远不会发生内存越界,免去一些边界讨论。

  2. 内存对齐:这是向量化编程中极其重要的一环。我们使用 alignas(Block) 这一 C++11 关键字强制 mData 数组在内存中按照 Block 的对齐要求——32 字节——对齐。有了这个保证,我们在后续代码中就可以安全地使用 _mm256_load_si256_mm256_store_si256 这类要求内存对齐的指令。相比于非对齐的 loadu/storeu,对齐的内存访问在某些微架构上能带来可观的性能提升。

  3. 边界清理:由于我们在末尾分配了一些闲置位,且某些批量操作会把超出 kSize 范围的无效位也置为 1,这会严重干扰后续的操作。因此,我们需要一个 trim() / normalize() 函数,在每次修改操作完成后,将超出 kSize 的无效位强制清零。具体来说,首先用位掩码清理最后一个有效 Word 中多余的高位;然后用 memset 将后面纯粹作为填充的 Word 全部清零。

  4. 字块访问:定义了几个函数来方便我们访问或修改指定的字或块。其中 _mm256_load_si256 以及 _mm256_store_si256 是 AVX2 指令集中的指令,位于 <immintrin.h> 头文件中,可以查询 Intel Intrinsics Guide 获取详细信息。本文中所有用到的 AVX2 指令及其作用见下,其中等价代码仅作辅助理解使用:

简单操作向量化

有了上面这些指令,我们就可以向量化一些简单的操作了。

逻辑运算

即按位与、或、异或、差。这是最容易向量化的部分了,两个 bitset 同时遍历每个块,然后通过 _mm256_and_si256_mm256_or_si256_mm256_xor_si256_mm256_andnot_si256 计算结果即可。注意集合差对应的是反的与非。

friend Bitset operator&(const Bitset& A, const Bitset& B) {
    Bitset result;
    for (size_t i = 0; i != kBlockCount; ++i)
        result.blockSet(i, _mm256_and_si256(B.blockAt(i), A.blockAt(i)));
    return result;
}

friend Bitset operator|(const Bitset& A, const Bitset& B) {
    Bitset result;
    for (size_t i = 0; i != kBlockCount; ++i)
        result.blockSet(i, _mm256_or_si256(B.blockAt(i), A.blockAt(i)));
    return result;
}

friend Bitset operator^(const Bitset& A, const Bitset& B) {
    Bitset result;
    for (size_t i = 0; i != kBlockCount; ++i)
        result.blockSet(i, _mm256_xor_si256(B.blockAt(i), A.blockAt(i)));
    return result;
}

friend Bitset operator-(const Bitset& A, const Bitset& B) {
    Bitset result;
    for (size_t i = 0; i != kBlockCount; ++i)
        result.blockSet(i, _mm256_andnot_si256(B.blockAt(i), A.blockAt(i)));
    return result;
}

关系查询

判断两个 bitset 的关系,注意这里的 >=><=< 分别指的是包含、真包含、子集、真子集,也就是 \supseteq,\supset,\subseteq,\subset 这四个关系,并非二进制数比较。

相等判断可以使用 memcmp 实现,但是由于我们保证内存对齐所以手写向量化可能效率更高,具体来说我们求两个 bitset 块对位异或,然后用 _mm256_testz_si256 判断其是否全 0。包含判断可以按块 _mm256_testc_si256,真包含在包含的基础上判断不等即可。

friend bool operator==(const Bitset& A, const Bitset& B) {
    for (size_t i = 0; i != kBlockCount; ++i) {
        Block bxa = _mm256_xor_si256(B.blockAt(i), A.blockAt(i));
        if (!_mm256_testz_si256(bxa, bxa)) return false;
    }
    return true;
}

friend bool operator!=(const Bitset& A, const Bitset& B) {
    for (size_t i = 0; i != kBlockCount; ++i) {
        Block bxa = _mm256_xor_si256(B.blockAt(i), A.blockAt(i));
        if (!_mm256_testz_si256(bxa, bxa)) return true;
    }
    return false;
}

friend bool operator<=(const Bitset& A, const Bitset& B) {
    for (size_t i = 0; i != kBlockCount; ++i)
        if (!_mm256_testc_si256(B.blockAt(i), A.blockAt(i)))
            return false;
    return true;
}

friend bool operator>=(const Bitset& A, const Bitset& B) {
    for (size_t i = 0; i != kBlockCount; ++i)
        if (!_mm256_testc_si256(A.blockAt(i), B.blockAt(i)))
            return false;
    return true;
}

friend bool operator<(const Bitset& A, const Bitset& B) {
    bool different = false;
    for (size_t i = 0; i != kBlockCount; ++i) {
        Block a = A.blockAt(i), b = B.blockAt(i);
        if (!_mm256_testc_si256(b, a)) return false;
        different |= !_mm256_testc_si256(a, b);
    }
    return different;
}

friend bool operator>(const Bitset& A, const Bitset& B) {
    bool different = false;
    for (size_t i = 0; i != kBlockCount; ++i) {
        Block a = A.blockAt(i), b = B.blockAt(i);
        if (!_mm256_testc_si256(a, b)) return false;
        different |= !_mm256_testc_si256(b, a);
    }
    return different;
}

区间修改

区间赋值,区间翻转。如果逐位操作,显然太慢了。标准的向量化区间操作通常采用“头部——躯干——尾部”的三段式处理策略。

首先是头部,position 可能并不对齐到 64 位字的边界。我们需要先处理第一个不完整的字,通过位移生成掩码,只修改该字中的部分位,使其对齐到下一个 Word 的起始位置。然后躯干,这是发挥向量化威力的地方。对于中间成块的完整数据,我们直接使用 _mm256_storeu_si256 每次写入 256 位(全 1 或全 0),或者用 _mm256_xor_si256 每次翻转 256 位。剩余不足 256 位但完整的 64 位字,则用普通的标量赋值处理。

最后剩下的不足 64 位的尾部,再次生成掩码进行局部修改。

void set(size_t position, size_t length) {
    if (position + length > kSize || !length) return;
    size_t wordIndex = position / kWordSize, bitInWord = position % kWordSize;
    if (bitInWord) {
        size_t headLength = kWordSize - bitInWord < length ? kWordSize - bitInWord : length;
        mData[wordIndex++] |= ((1ULL << headLength) - 1) << bitInWord;
        length -= headLength;
    }
    for (const Block value = _mm256_set1_epi64x(-1); length >= sizeof(Block) * CHAR_BIT; ) {
        _mm256_storeu_si256((Block*)&mData[wordIndex], value);
        length -= sizeof(Block) * CHAR_BIT;
        wordIndex += kBlockSize;
    }
    while (length >= kWordSize) {
        mData[wordIndex++] = -1ULL;
        length -= kWordSize;
    }
    mData[wordIndex] |= (1ULL << length) - 1;
}

void unset(size_t position, size_t length) {
    if (position + length > kSize || !length) return;
    size_t wordIndex = position / kWordSize, bitInWord = position % kWordSize;
    if (bitInWord) {
        size_t headLength = kWordSize - bitInWord < length ? kWordSize - bitInWord : length;
        mData[wordIndex++] &= ~(((1ULL << headLength) - 1) << bitInWord);
        length -= headLength;
    }
    for (const Block value = _mm256_setzero_si256(); length >= sizeof(Block) * CHAR_BIT; ) {
        _mm256_storeu_si256((Block*)&mData[wordIndex], value);
        length -= sizeof(Block) * CHAR_BIT;
        wordIndex += kBlockSize;
    }
    while (length >= kWordSize) {
        mData[wordIndex++] = 0ULL;
        length -= kWordSize;
    }
    mData[wordIndex] &= ~((1ULL << length) - 1);
}

void flip(size_t position, size_t length) {
    if (position + length > kSize || !length) return;
    size_t wordIndex = position / kWordSize, bitInWord = position % kWordSize;
    if (bitInWord) {
        size_t headLength = kWordSize - bitInWord < length ? kWordSize - bitInWord : length;
        mData[wordIndex++] ^= ((1ULL << headLength) - 1) << bitInWord;
        length -= headLength;
    }
    for (const Block value = _mm256_set1_epi64x(-1); length >= sizeof(Block) * CHAR_BIT; ) {
        _mm256_storeu_si256((Block*)&mData[wordIndex], _mm256_xor_si256(_mm256_loadu_si256((const Block*)&mData[wordIndex]), value));
        length -= sizeof(Block) * CHAR_BIT;
        wordIndex += kBlockSize;
    }
    while (length >= kWordSize) {
        mData[wordIndex++] ^= -1ULL;
        length -= kWordSize;
    }
    mData[wordIndex] ^= (1ULL << length) - 1;
}

状态查询

查询 bitset 是否全 1、全 0、存在 1。借助 AVX2 提供的测试指令,我们可以极其高效地完成这些查询。_mm256_testz_si256(A, B) 可以判断 (A & B) == 0,如果我们将 B 置为全 1 向量,这就相当于判断 A == 0,可以利用其实现 none 以及 any。至于 all 的判断,可以用 _mm256_testc_si256(A, B) 判断 (~A & B) == 0,如果我们将 B 置为全 1 向量,这就相当于判断 A 是全 1 向量。

但是请一定要注意特判末尾填充位!如果 kSize 不是 256 的整数倍,最后一个块中会包含无效的填充位。这些填充位在 trim() 的作用下永远是 0。如果我们直接对最后一个块调用 _mm256_testc_si256 来判断 all(),由于填充位是 0,这显然会返回错误的结果。因此,对于最后一个不完整的块,必须动态构造一个精确匹配剩余有效位的掩码。

bool none() const {
    const size_t fullBlocks = kSize / (sizeof(Block) * CHAR_BIT);
    for (size_t i = 0; i != fullBlocks; ++i) {
        Block mask = _mm256_set1_epi8(-1), block = blockAt(i);
        if (!_mm256_testz_si256(block, mask)) return false;
    }
    if (size_t remaining = kSize % (sizeof(Block) * CHAR_BIT)) {
        Word maskArray[kBlockSize] = {};
        for (size_t i = 0; i != kBlockSize; ++i) {
            if (remaining >= kWordSize) {
                maskArray[i] = ~0ULL;
                remaining -= kWordSize;
            } else {
                maskArray[i] = (1ULL << remaining) - 1;
                break;
            }
        }
        Block mask = _mm256_loadu_si256((const Block*)maskArray), block = blockAt(fullBlocks);
        if (!_mm256_testz_si256(block, mask)) return false;
    }
    return true;
}

bool any() const {
    const size_t fullBlocks = kSize / (sizeof(Block) * CHAR_BIT);
    for (size_t i = 0; i != fullBlocks; ++i) {
        Block mask = _mm256_set1_epi8(-1), block = blockAt(i);
        if (!_mm256_testz_si256(block, mask)) return true;
    }
    if (size_t remaining = kSize % (sizeof(Block) * CHAR_BIT)) {
        Word maskArray[kBlockSize] = {};
        for (size_t i = 0; i != kBlockSize; ++i) {
            if (remaining >= kWordSize) {
                maskArray[i] = ~0ULL;
                remaining -= kWordSize;
            } else {
                maskArray[i] = (1ULL << remaining) - 1;
                break;
            }
        }
        Block mask = _mm256_loadu_si256((const Block*)maskArray), block = blockAt(fullBlocks);
        if (!_mm256_testz_si256(block, mask)) return true;
    }
    return false;
}

bool all() const {
    const size_t fullBlocks = kSize / (sizeof(Block) * CHAR_BIT);
    for (size_t i = 0; i != fullBlocks; ++i) {
        Block mask = _mm256_set1_epi8(-1), block = blockAt(i);
        if (!_mm256_testc_si256(block, mask)) return false;
    }
    if (size_t remaining = kSize % (sizeof(Block) * CHAR_BIT)) {
        Word maskArray[kBlockSize] = {};
        for (size_t i = 0; i != kBlockSize; ++i) {
            if (remaining >= kWordSize) {
                maskArray[i] = ~0ULL;
                remaining -= kWordSize;
            } else {
                maskArray[i] = (1ULL << remaining) - 1;
                break;
            }
        }
        Block mask = _mm256_loadu_si256((const Block*)maskArray), block = blockAt(fullBlocks);
        if (!_mm256_testc_si256(block, mask)) return false;
    }
    return true;
}

位置查找

以查找下一个设置为 1 的位置为例。在部分如匈牙利算法、求最大团的图论算法中,我们经常需要遍历 bitset 中所有为 1 的位。这里我们可以模仿 libstdc++ std::bitset_Find_next 的实现方式,并加以向量化。

首先找到查找位置所在的字,通过掩码消掉该查找位置之前的所有位,然后检查其是否全 0,若否则直接使用 __builtin_ctzll 找到第一个设置为 1 的位置并返回即可。然后是向量化通用思路,把该字所在的块中剩余的字全部处理后整块处理。整块处理时,类比 none() 的实现方式使用 _mm256_testz_si256 快速判断当前块是否全 0,若否则遍历这个块中的每个字找到第一个非全 0 的字然后用 __builtin_ctzll 找到第一个设置为 1 的具体位置返回即可。

0 的方式与之类似,其实本质上就是将 bitset 整体取反后查 1,只需要在判断前将取出的字或块按位取反即可,在此不过多赘述。

size_t findFirstSet(size_t position) const {
    if (position >= kSize) return size_t(-1);
    size_t wordIndex = position / kWordSize, bitInWord = position % kWordSize;
    size_t blockIndex = wordIndex / kBlockSize, wordInBlock = wordIndex % kBlockSize;
    if (Word word = mData[wordIndex] & ~((1ULL << bitInWord) - 1)) {
        size_t result = wordIndex * kWordSize + size_t(__builtin_ctzll(word));
        return result < kSize ? result : size_t(-1);
    }
    while (++wordInBlock != kBlockSize) {
        size_t current = blockIndex * kBlockSize + wordInBlock;
        if (current >= kWordCount) break;
        if (Word word = mData[current]) {
            size_t result = current * kWordSize + __builtin_ctzll(word);
            return result < kSize ? result : size_t(-1);
        }
    }
    while (++blockIndex != kBlockCount) {
        Block mask = _mm256_set1_epi64x(-1), block = blockAt(blockIndex);
        if (_mm256_testz_si256(block, mask)) continue;
        for (size_t i = 0; i != kBlockSize; ++i) {
            size_t current = blockIndex * kBlockSize + i;
            if (current >= kWordCount) break;
            if (Word word = mData[current]) {
                size_t result = current * kWordSize + __builtin_ctzll(word);
                return result < kSize ? result : size_t(-1);
            }
        }
    }
    return size_t(-1);
}

size_t findFirstUnset(size_t position) const {
    if (position >= kSize) return size_t(-1);
    size_t wordIndex = position / kWordSize, bitInWord = position % kWordSize;
    size_t blockIndex = wordIndex / kBlockSize, wordInBlock = wordIndex % kBlockSize;
    if (Word word = ~mData[wordIndex] & ~((1ULL << bitInWord) - 1)) {
        size_t result = wordIndex * kWordSize + size_t(__builtin_ctzll(word));
        return result < kSize ? result : size_t(-1);
    }
    while (++wordInBlock != kBlockSize) {
        size_t current = blockIndex * kBlockSize + wordInBlock;
        if (current >= kWordCount) break;
        if (Word word = ~mData[current]) {
            size_t result = current * kWordSize + __builtin_ctzll(word);
            return result < kSize ? result : size_t(-1);
        }
    }
    while (++blockIndex != kBlockCount) {
        Block mask = _mm256_set1_epi64x(-1), block = blockAt(blockIndex);
        if (_mm256_testc_si256(block, mask)) continue;
        for (size_t i = 0; i != kBlockSize; ++i) {
            size_t current = blockIndex * kBlockSize + i;
            if (current >= kWordCount) break;
            if (Word word = ~mData[current]) {
                size_t result = current * kWordSize + __builtin_ctzll(word);
                return result < kSize ? result : size_t(-1);
            }
        }
    }
    return size_t(-1);
}

复杂操作向量化

部分操作向量化较为复杂,需要一定的思考。

左右移位

以左移为例。这是向量化 bitset 最难啃的骨头,为什么?因为 AVX2 没有提供跨 64 位边界的整体移位指令。_mm256_slli_epi64 只能让 464 位的字各自在内部移位,无法把低位字溢出的位补到高位字中。为了解决这个问题,我们将总移位量 step 拆分为 wordShift = step / kWordSize 整字偏移量,以及 bitShift = step % kWordSize 位偏移量。然后分类讨论:

首先是 bitShift == 0 的情况,也就是整字偏移。这种情况下我们根本不需要位运算,直接使用 memmove 搬运内存即可。注意由于复制源区间和目标区间存在重叠,所以这里不能使用 memcpy。最后用 memset 把移动掉的位置清零即可。

当存在位偏移时,我们做的事情是:将当前字左移 bitShift 位,并提取相邻低地址字溢出的高 kWordSize - bitShift 位,将两者拼接,也就是或起来。为了向量化这个过程,我们使用非对齐加载 _mm256_loadu_si256,故意错开一个字的位置加载相邻的数据块,然后分别进行向量左移和右移,最后拼接。左右移使用 _mm256_sll_epi64_mm256_srl_epi64 实现,通过 _mm_cvtsi32_si128 将偏移量 bitShift 转成要求的 __m128i 形式传进去。

需要注意的是,左移必须从高地址向低地址逆向遍历,右移必须从低向高遍历,这是为了防止源数据在被读取前就被覆盖。同时由于我们错开了一个字,所以我们需要使用非对齐加载。在处理完所有块之后,剩余不足整块的若干个字,我们回退到普通的标量循环进行拼接即可。最后别忘了调用 trim() 清理边界。

Bitset& operator<<=(size_t step) {
    if (step >= kSize) return unset(), *this;
    size_t wordShift = step / kWordSize, bitShift = step % kWordSize;
    if (!bitShift) {
        memmove(mData + wordShift, mData, (kWordCount - wordShift) * sizeof(Word));
        memset(mData, 0, wordShift * sizeof(Word));
    } else {
        size_t remaining = kWordCount - wordShift - 1;
        Word *destination = mData + kWordCount, *source = destination - wordShift;
        __m128i lShift = _mm_cvtsi32_si128(int(bitShift)), rShift = _mm_cvtsi32_si128(int(kWordSize - bitShift));
        while (remaining >= kBlockSize) {
            destination -= kBlockSize, source -= kBlockSize;
            Block low = _mm256_srl_epi64(_mm256_loadu_si256((const Block*)&source[-1]), rShift);
            Block high = _mm256_sll_epi64(_mm256_loadu_si256((const Block*)&source[0]), lShift);
            _mm256_storeu_si256((Block*)destination, _mm256_or_si256(low, high));
            remaining -= kBlockSize;
        }
        while (remaining) {
            --destination, --source;
            *destination = (source[0] << bitShift) | (source[-1] >> (kWordSize - bitShift));
            --remaining;
        }
        *--destination = *--source << bitShift;
        memset(mData, 0, wordShift * sizeof(Word));
    }
    trim();
    return *this;
}

Bitset& operator>>=(size_t step) {
    if (step >= kSize) return unset(), *this;
    size_t wordShift = step / kWordSize, bitShift = step % kWordSize;
    if (!bitShift) {
        memmove(mData, mData + wordShift, (kWordCount - wordShift) * sizeof(Word));
        memset(mData + kWordCount - wordShift, 0, wordShift * sizeof(Word));
    } else {
        size_t remaining = kWordCount - wordShift - 1;
        Word *destination = mData, *source = mData + wordShift;
        __m128i rShift = _mm_cvtsi32_si128(int(bitShift)), lShift = _mm_cvtsi32_si128(int(kWordSize - bitShift));
        while (remaining >= kBlockSize) {
            Block low = _mm256_srl_epi64(_mm256_loadu_si256((const Block*)&source[0]), rShift);
            Block high = _mm256_sll_epi64(_mm256_loadu_si256((const Block*)&source[1]), lShift);
            _mm256_storeu_si256((Block*)destination, _mm256_or_si256(low, high));
            destination += kBlockSize, source += kBlockSize;
            remaining -= kBlockSize;
        }
        while (remaining) {
            *destination = (source[0] >> bitShift) | (source[1] << (kWordSize - bitShift));
            ++destination, ++source;
            --remaining;
        }
        *destination = *source >> bitShift;
        memset(mData + kWordCount - wordShift, 0, wordShift * sizeof(Word));
    }
    trim();
    return *this;
}

个数统计

标量 bitset 的 popcnt 方法是对每个字调用 __builtin_popcountll 并求和,可惜 AVX2 指令集中并没有现成的 popcnt 指令。类比常规 popcnt 的查表法,联想到 _mm256_shuffle_epi8 的查表功能,以此为突破口尝试实现。

首先 _mm256_shuffle_epi8 的查表是每个 8 位整数的低 4 位作为索引,于是我们对所有 4 位整数的 popcnt 打表,得到查找表 table,通过掩码消去 block 中每个 8 位整数的高四位避免其最高位产生的影响,然后 _mm256_shuffle_epi8(table, block) 得到的结果 low 的每个 8 位整数就存储了 block 中的对应 8 位整数的低 4 位部分的 popcnt。使用 _mm256_srli_epi16block 右移 4 位重复一遍上述过程得到的结果 high 就存储了 block 中的对应 8 位整数的高 4 位部分的 popcnt。

得到 lowhigh 之后,使用 _mm256_add_epi8 将两者对位相加,所得结果 sum 的每个 8 位整数就存储了 block 中对应的 8 位整数的 popcnt。接下来我们想对所有 8 位整数求和,然而指令集中没有直接实现这一功能的指令,唯一相关的指令是 _mm256_sad_epu8‌ 实现的 \sum|a_i-b_i|。但是我们可以令 \forall b_i=0,也就是让 sum 和一个全 0 向量做 SAD,即 _mm256_sad_epu8(sum, _mm256_setzero_si256()),所得结果的每个 64 位整数就存储了 sum 中对应 88 位整数的和,称其为这个块的 popcnt 结果向量。

接着,我们开一个向量 result 作为累加器。每当我们得到一个块的 popcnt 结果向量后,使用 _mm256_add_epi64 将结果向量对位累加到 result 上。最后将 result_mm256_storeu_si256 存到 464 位整数组成的数组上,对这个数组求和即为整个 bitset 的 popcnt 了。

size_t count() const {
    const Block mask = _mm256_set1_epi8(0x0F);
    const Block table = _mm256_setr_epi8(
        0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
        0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4
    );
    Block result = _mm256_setzero_si256();
    for (size_t i = 0; i != kBlockCount; ++i) {
        Block block = blockAt(i);
        Block low = _mm256_shuffle_epi8(table, _mm256_and_si256(block, mask));
        Block high = _mm256_shuffle_epi8(table, _mm256_and_si256(_mm256_srli_epi16(block, 4), mask));
        result = _mm256_add_epi64(result, _mm256_sad_epu8(_mm256_add_epi8(low, high), _mm256_setzero_si256()));
    }
    Word parts[kBlockSize];
    _mm256_storeu_si256((Block*)parts, result);
    return size_t(parts[0] + parts[1] + parts[2] + parts[3]);
}

size_t popcnt() const {
    const Block mask = _mm256_set1_epi8(0x0F);
    const Block table = _mm256_setr_epi8(
        0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
        0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4
    );
    Block result = _mm256_setzero_si256();
    for (size_t i = 0; i != kBlockCount; ++i) {
        Block block = blockAt(i);
        Block low = _mm256_shuffle_epi8(table, _mm256_and_si256(block, mask));
        Block high = _mm256_shuffle_epi8(table, _mm256_and_si256(_mm256_srli_epi16(block, 4), mask));
        result = _mm256_add_epi64(result, _mm256_sad_epu8(_mm256_add_epi8(low, high), _mm256_setzero_si256()));
    }
    Word parts[kBlockSize];
    _mm256_storeu_si256((Block*)parts, result);
    return size_t(parts[0] + parts[1] + parts[2] + parts[3]);
}

嵌套运算

以上实现的向量化 Bitset 在单一操作上的表现十分优秀,但是这个写法在类似 A = ~B & (C | D) 的嵌套表达式上会暴露出严重的性能缺陷:

  1. 临时对象泛滥:计算 ~B 会生成一个临时 bitset,计算 C | D 会生成第二个,最后做 & 又会生成第三个。
  2. 多次遍历内存:上述每一步都会触发一次对整个底层数组的完整遍历,导致极高的内存带宽开销和缓存丢失。

如果手写,我们可以通过循环融合,也就是手写 A[i] = ~B[i] & (C[i] | D[i]) 来避免这一问题,实现单趟遍历。但每次都手写显然违背了封装的初衷。有没有办法让编译器自动帮我们完成这种“循环融合”呢?答案就是 C++ 模板元编程的黑魔法——表达式模板。

表达式模板的核心思想是:重载运算符时不进行实际计算,而是返回一个轻量级的对象,用于在编译期记录表达式的抽象语法树。直到最终赋值给目标对象时,才触发真正的计算。其中语法树是一棵二叉树,其叶子节点代表一个 bitset,非叶子节点有一个二元运算表示“此节点代表的 bitset 为其左右儿子代表的 bitset 做此二元运算后的结果”。对一个表达式构建出语法树后,求出根节点代表的 bitset 即为此表达式的值。

首先我们定义一个基类 Expression,所有继承自此类型的类型均被视为表达式树上的节点。同时我们在基类提供一个向其后代转化的方法。这里利用 CRTP 即奇异递归模板模式来实现静态多态,避免虚函数带来的运行时开销。我们的 Bitset 类也会继承自 Expression<Bitset<kSize>>,使其成为语法树的叶子节点。接下来,我们为每一种运算定义一个节点结构体:

template <typename E>
struct Expression {
    inline const E& operator()() const {
        return static_cast<const E&>(*this);
    }
};

template <typename E>
struct BitsetFlip : Expression<BitsetFlip<E>> {
    const E& mE;

    BitsetFlip(const E& e) : mE{e} {}

    inline __m256i blockAt(size_t i) const {
        return _mm256_xor_si256(mE.blockAt(i), _mm256_set1_epi64x(-1));
    }
};

template <typename E>
struct BitsetNot : Expression<BitsetNot<E>> {
    const E& mE;

    BitsetNot(const E& e) : mE{e} {}

    inline __m256i blockAt(size_t i) const {
        return _mm256_xor_si256(mE.blockAt(i), _mm256_set1_epi64x(-1));
    }
};

template <typename XE, typename YE>
struct BitsetAnd : Expression<BitsetAnd<XE, YE>> {
    const XE& mXE;
    const YE& mYE;

    BitsetAnd(const XE& xE, const YE& yE) : mXE{xE}, mYE{yE} {}

    inline __m256i blockAt(size_t i) const {
        return _mm256_and_si256(mYE.blockAt(i), mXE.blockAt(i));
    }
};

template <typename XE, typename YE>
struct BitsetOr : Expression<BitsetOr<XE, YE>> {
    const XE& mXE;
    const YE& mYE;

    BitsetOr(const XE& xE, const YE& yE) : mXE{xE}, mYE{yE} {}

    inline __m256i blockAt(size_t i) const {
        return _mm256_or_si256(mYE.blockAt(i), mXE.blockAt(i));
    }
};

template <typename XE, typename YE>
struct BitsetXor : Expression<BitsetXor<XE, YE>> {
    const XE& mXE;
    const YE& mYE;

    BitsetXor(const XE& xE, const YE& yE) : mXE{xE}, mYE{yE} {}

    inline __m256i blockAt(size_t i) const {
        return _mm256_xor_si256(mYE.blockAt(i), mXE.blockAt(i));
    }
};

template <typename XE, typename YE>
struct BitsetAnt : Expression<BitsetAnt<XE, YE>> {
    const XE& mXE;
    const YE& mYE;

    BitsetAnt(const XE& xE, const YE& yE) : mXE{xE}, mYE{yE} {}

    inline __m256i blockAt(size_t i) const {
        return _mm256_andnot_si256(mYE.blockAt(i), mXE.blockAt(i));
    }
};

注意这些运算节点结构体内部只保留了左右子节点的常引用。blockAt(i) 定义了实际的计算逻辑,但是此时并未执行。接下来我们重载全局的运算符,让它们返回语法树节点,而不是具体的 bitset:

template <typename E>
inline BitsetFlip<E> operator~(const Expression<E>& e) {
    return BitsetFlip<E>(e());
}

template <typename E>
inline BitsetNot<E> operator!(const Expression<E>& e) {
    return BitsetNot<E>(e());
}

template <typename XE, typename YE>
inline BitsetAnd<XE, YE> operator&(const Expression<XE>& xE, const Expression<YE>& yE) {
    return BitsetAnd<XE, YE>(xE(), yE());
}

template <typename XE, typename YE>
inline BitsetOr<XE, YE> operator|(const Expression<XE>& xE, const Expression<YE>& yE) {
    return BitsetOr<XE, YE>(xE(), yE());
}

template <typename XE, typename YE>
inline BitsetXor<XE, YE> operator^(const Expression<XE>& xE, const Expression<YE>& yE) {
    return BitsetXor<XE, YE>(xE(), yE());
}

template <typename XE, typename YE>
inline BitsetAnt<XE, YE> operator-(const Expression<XE>& xE, const Expression<YE>& yE) {
    return BitsetAnt<XE, YE>(xE(), yE());
}

最后,我们在 bitset 中提供接受 Expression 的方法。

Bitset(const Bitset&) = default;
Bitset(Bitset&&) = default;

template <typename E>
Bitset(const Expression<E>& e) {
    for (size_t i = 0; i != kBlockCount; ++i)
        blockSet(i, e().blockAt(i));
    trim();
}

Bitset& operator=(const Bitset&) = default;
Bitset& operator=(Bitset&&) = default;

template <typename E>
Bitset& operator=(const Expression<E>& e) {
    for (size_t i = 0; i != kBlockCount; ++i)
        blockSet(i, e().blockAt(i));
    trim();
    return *this;
}

接下来我们以 A = ~B & (C | D) 为例讲解语法树的具体流程。

  1. ~B 调用 operator~,得到一个 BitsetFlip 对象。无整体赋值。
  2. (C | D) 调用 operator|,得到一个 BitsetOr 对象。无整体赋值。
  3. ~B & (C | D) 调用 operator&,得到一个 BitsetAnd 对象。无整体赋值。
  4. A = ~B & (C | D) 调用 operator=
    • 遍历每个块,调用 BitsetAndblockAt
    • BitsetAndblockAt 递归调用其左右子节点 BitsetFlipBitsetOrblockAt
    • BitsetFlipblockAt 递归调用其子节点 BblockAt 得到具体的值。
    • BitsetOrblockAt 递归调用其左右子节点 CDblockAt 得到具体的值。

由于所有的 blockAt 均被标记为 inline,所以这些调用会被全部内联!最终生成的代码等价于:

for (size_t i = 0; i != kBlockCount; ++i) {
    A.blockSet(i, _mm256_and_si256(
        _mm256_xor_si256(B.blockAt(i), _mm256_set1_epi64x(-1)),
        _mm256_or_si256(C.blockAt(i), D.blockAt(i))
    ));
}

这完美地实现了循环融合!没有临时对象,只有一次内存遍历,所有的计算被融合进了一条紧凑的向量化指令流中。这是一个非常优雅且高效的解决方案。但是请注意下面这段代码:

auto foo() {
    Bitset<16> result;
    return ~result;
}

int main() {
    auto tmp = foo();
    Bitset<16> bs = tmp;
}

程序会爆掉。怎么会这样?我们来分析一下。首先 ~result 的类型是 BitsetFlip,于是 foo 的返回值被 auto 推断为 BitsetFlip,所以 tmp 也被 auto 推断为 BitsetFlip 类型。但是,此时 tmp 保存的引用是 result 的,而 result 的生命周期已经结束了,tmp 保存的是悬垂引用,那么下面 bs = tmp 便成为了 UB。

那有人可能会想到一种解决方案:“把语法树节点的拷贝构造、移动构造、拷贝赋值、移动赋值全部删除,这样就不能写 auto tmp = foo() 了!”这种写法有一定的道理,但是 C++17 引入了强制拷贝消除,而 foo() 为纯右值,编译器会在 tmp 的位置上原位构造 BitsetFlip 对象,完全不需要复制构造和移动构造函数,因而这种方案行不通。

所以请不要在 Bitset 相关的表达式中使用 auto 进行类型推断,除非你非常清楚你在做什么

参考实现

见云剪贴板。

:::info[或者这里]

#ifndef BITSET_H
#define BITSET_H 202605L

#include <stdint.h>
#include <string.h>
#include <immintrin.h>

template <typename E>
struct Expression {
    inline const E& operator()() const {
        return static_cast<const E&>(*this);
    }
};

template <typename E>
struct BitsetFlip : Expression<BitsetFlip<E>> {
    const E& mE;

    BitsetFlip(const E& e) : mE{e} {}

    inline __m256i blockAt(size_t i) const {
        return _mm256_xor_si256(mE.blockAt(i), _mm256_set1_epi64x(-1));
    }
};

template <typename E>
struct BitsetNot : Expression<BitsetNot<E>> {
    const E& mE;

    BitsetNot(const E& e) : mE{e} {}

    inline __m256i blockAt(size_t i) const {
        return _mm256_xor_si256(mE.blockAt(i), _mm256_set1_epi64x(-1));
    }
};

template <typename XE, typename YE>
struct BitsetAnd : Expression<BitsetAnd<XE, YE>> {
    const XE& mXE;
    const YE& mYE;

    BitsetAnd(const XE& xE, const YE& yE) : mXE{xE}, mYE{yE} {}

    inline __m256i blockAt(size_t i) const {
        return _mm256_and_si256(mYE.blockAt(i), mXE.blockAt(i));
    }
};

template <typename XE, typename YE>
struct BitsetOr : Expression<BitsetOr<XE, YE>> {
    const XE& mXE;
    const YE& mYE;

    BitsetOr(const XE& xE, const YE& yE) : mXE{xE}, mYE{yE} {}

    inline __m256i blockAt(size_t i) const {
        return _mm256_or_si256(mYE.blockAt(i), mXE.blockAt(i));
    }
};

template <typename XE, typename YE>
struct BitsetXor : Expression<BitsetXor<XE, YE>> {
    const XE& mXE;
    const YE& mYE;

    BitsetXor(const XE& xE, const YE& yE) : mXE{xE}, mYE{yE} {}

    inline __m256i blockAt(size_t i) const {
        return _mm256_xor_si256(mYE.blockAt(i), mXE.blockAt(i));
    }
};

template <typename XE, typename YE>
struct BitsetAnt : Expression<BitsetAnt<XE, YE>> {
    const XE& mXE;
    const YE& mYE;

    BitsetAnt(const XE& xE, const YE& yE) : mXE{xE}, mYE{yE} {}

    inline __m256i blockAt(size_t i) const {
        return _mm256_andnot_si256(mYE.blockAt(i), mXE.blockAt(i));
    }
};

template <typename E>
inline BitsetFlip<E> operator~(const Expression<E>& e) {
    return BitsetFlip<E>(e());
}

template <typename E>
inline BitsetNot<E> operator!(const Expression<E>& e) {
    return BitsetNot<E>(e());
}

template <typename XE, typename YE>
inline BitsetAnd<XE, YE> operator&(const Expression<XE>& xE, const Expression<YE>& yE) {
    return BitsetAnd<XE, YE>(xE(), yE());
}

template <typename XE, typename YE>
inline BitsetOr<XE, YE> operator|(const Expression<XE>& xE, const Expression<YE>& yE) {
    return BitsetOr<XE, YE>(xE(), yE());
}

template <typename XE, typename YE>
inline BitsetXor<XE, YE> operator^(const Expression<XE>& xE, const Expression<YE>& yE) {
    return BitsetXor<XE, YE>(xE(), yE());
}

template <typename XE, typename YE>
inline BitsetAnt<XE, YE> operator-(const Expression<XE>& xE, const Expression<YE>& yE) {
    return BitsetAnt<XE, YE>(xE(), yE());
}

template <size_t kSize>
class Bitset : public Expression<Bitset<kSize>> {
protected:
    typedef uint64_t Word;
    typedef __m256i Block;

    static const size_t kWordSize = sizeof(Word) * CHAR_BIT;
    static const size_t kBlockSize = sizeof(Block) / sizeof(Word);
    static const size_t kWordCount = (kSize + kWordSize - 1) / kWordSize;
    static const size_t kBlockCount = (kWordCount + kBlockSize - 1) / kBlockSize;

    alignas(Block) Word mData[kBlockCount * kBlockSize]{};

public:
    Bitset() = default;
    ~Bitset() = default;

    inline Word wordAt(size_t position) const {
        return mData[position];
    }

    inline Block blockAt(size_t position) const {
        return _mm256_load_si256((const Block*)&mData[position * kBlockSize]);
    }

protected:
    inline void wordSet(size_t position, Word value) {
        mData[position] = value;
    }

    inline void blockSet(size_t position, Block value) {
        _mm256_store_si256((Block*)&mData[position * kBlockSize], value);
    }

    void trim() {
        if (kSize % kWordSize) mData[kWordCount - 1] &= (1ULL << (kSize % kWordSize)) - 1;
        if (kBlockCount * kBlockSize > kWordCount) memset(mData + kWordCount, 0, (kBlockCount * kBlockSize - kWordCount) * sizeof(Word));
    }

    void normalize() {
        if (kSize % kWordSize) mData[kWordCount - 1] &= (1ULL << (kSize % kWordSize)) - 1;
        if (kBlockCount * kBlockSize > kWordCount) memset(mData + kWordCount, 0, (kBlockCount * kBlockSize - kWordCount) * sizeof(Word));
    }

public:
    Bitset(const Bitset&) = default;
    Bitset(Bitset&&) = default;

    template <typename E>
    Bitset(const Expression<E>& e) {
        for (size_t i = 0; i != kBlockCount; ++i)
            blockSet(i, e().blockAt(i));
        trim();
    }

    Bitset& operator=(const Bitset&) = default;
    Bitset& operator=(Bitset&&) = default;

    template <typename E>
    Bitset& operator=(const Expression<E>& e) {
        for (size_t i = 0; i != kBlockCount; ++i)
            blockSet(i, e().blockAt(i));
        trim();
        return *this;
    }

    bool operator[](size_t position) const {
        if (position >= kSize) return false;
        return mData[position / kWordSize] >> position % kWordSize & 1;
    }

    bool operator()(size_t position) const {
        if (position >= kSize) return false;
        return mData[position / kWordSize] >> position % kWordSize & 1;
    }

    void set(size_t position) {
        if (position >= kSize) return;
        mData[position / kWordSize] |= 1ULL << position % kWordSize;
    }

    void unset(size_t position) {
        if (position >= kSize) return;
        mData[position / kWordSize] &= ~(1ULL << position % kWordSize);
    }

    void flip(size_t position) {
        if (position >= kSize) return;
        mData[position / kWordSize] ^= 1ULL << position % kWordSize;
    }

    size_t size() const {
        return kSize;
    }

    size_t length() const {
        return kSize;
    }

    size_t count() const {
        const Block mask = _mm256_set1_epi8(0x0F);
        const Block table = _mm256_setr_epi8(
            0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
            0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4
        );
        Block result = _mm256_setzero_si256();
        for (size_t i = 0; i != kBlockCount; ++i) {
            Block block = blockAt(i);
            Block low = _mm256_shuffle_epi8(table, _mm256_and_si256(block, mask));
            Block high = _mm256_shuffle_epi8(table, _mm256_and_si256(_mm256_srli_epi16(block, 4), mask));
            result = _mm256_add_epi64(result, _mm256_sad_epu8(_mm256_add_epi8(low, high), _mm256_setzero_si256()));
        }
        Word parts[kBlockSize];
        _mm256_storeu_si256((Block*)parts, result);
        return size_t(parts[0] + parts[1] + parts[2] + parts[3]);
    }

    size_t popcnt() const {
        const Block mask = _mm256_set1_epi8(0x0F);
        const Block table = _mm256_setr_epi8(
            0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
            0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4
        );
        Block result = _mm256_setzero_si256();
        for (size_t i = 0; i != kBlockCount; ++i) {
            Block block = blockAt(i);
            Block low = _mm256_shuffle_epi8(table, _mm256_and_si256(block, mask));
            Block high = _mm256_shuffle_epi8(table, _mm256_and_si256(_mm256_srli_epi16(block, 4), mask));
            result = _mm256_add_epi64(result, _mm256_sad_epu8(_mm256_add_epi8(low, high), _mm256_setzero_si256()));
        }
        Word parts[kBlockSize];
        _mm256_storeu_si256((Block*)parts, result);
        return size_t(parts[0] + parts[1] + parts[2] + parts[3]);
    }

    bool none() const {
        const size_t fullBlocks = kSize / (sizeof(Block) * CHAR_BIT);
        for (size_t i = 0; i != fullBlocks; ++i) {
            Block mask = _mm256_set1_epi8(-1), block = blockAt(i);
            if (!_mm256_testz_si256(block, mask)) return false;
        }
        if (size_t remaining = kSize % (sizeof(Block) * CHAR_BIT)) {
            Word maskArray[kBlockSize] = {};
            for (size_t i = 0; i != kBlockSize; ++i) {
                if (remaining >= kWordSize) {
                    maskArray[i] = ~0ULL;
                    remaining -= kWordSize;
                } else {
                    maskArray[i] = (1ULL << remaining) - 1;
                    break;
                }
            }
            Block mask = _mm256_loadu_si256((const Block*)maskArray), block = blockAt(fullBlocks);
            if (!_mm256_testz_si256(block, mask)) return false;
        }
        return true;
    }

    bool any() const {
        const size_t fullBlocks = kSize / (sizeof(Block) * CHAR_BIT);
        for (size_t i = 0; i != fullBlocks; ++i) {
            Block mask = _mm256_set1_epi8(-1), block = blockAt(i);
            if (!_mm256_testz_si256(block, mask)) return true;
        }
        if (size_t remaining = kSize % (sizeof(Block) * CHAR_BIT)) {
            Word maskArray[kBlockSize] = {};
            for (size_t i = 0; i != kBlockSize; ++i) {
                if (remaining >= kWordSize) {
                    maskArray[i] = ~0ULL;
                    remaining -= kWordSize;
                } else {
                    maskArray[i] = (1ULL << remaining) - 1;
                    break;
                }
            }
            Block mask = _mm256_loadu_si256((const Block*)maskArray), block = blockAt(fullBlocks);
            if (!_mm256_testz_si256(block, mask)) return true;
        }
        return false;
    }

    bool all() const {
        const size_t fullBlocks = kSize / (sizeof(Block) * CHAR_BIT);
        for (size_t i = 0; i != fullBlocks; ++i) {
            Block mask = _mm256_set1_epi8(-1), block = blockAt(i);
            if (!_mm256_testc_si256(block, mask)) return false;
        }
        if (size_t remaining = kSize % (sizeof(Block) * CHAR_BIT)) {
            Word maskArray[kBlockSize] = {};
            for (size_t i = 0; i != kBlockSize; ++i) {
                if (remaining >= kWordSize) {
                    maskArray[i] = ~0ULL;
                    remaining -= kWordSize;
                } else {
                    maskArray[i] = (1ULL << remaining) - 1;
                    break;
                }
            }
            Block mask = _mm256_loadu_si256((const Block*)maskArray), block = blockAt(fullBlocks);
            if (!_mm256_testc_si256(block, mask)) return false;
        }
        return true;
    }

    void set() {
        const Block value = _mm256_set1_epi64x(-1);
        for (size_t i = 0; i != kBlockCount; ++i)
            blockSet(i, value);
        trim();
    }

    void unset() {
        const Block value = _mm256_setzero_si256();
        for (size_t i = 0; i != kBlockCount; ++i)
            blockSet(i, value);
        trim();
    }

    void flip() {
        const Block value = _mm256_set1_epi64x(-1);
        for (size_t i = 0; i != kBlockCount; ++i)
            blockSet(i, _mm256_xor_si256(blockAt(i), value));
        trim();
    }

    template <typename E>
    Bitset& operator&=(const Expression<E>& e) {
        for (size_t i = 0; i != kBlockCount; ++i)
            blockSet(i, _mm256_and_si256(blockAt(i), e().blockAt(i)));
        trim();
        return *this;
    }

    template <typename E>
    Bitset& operator|=(const Expression<E>& e) {
        for (size_t i = 0; i != kBlockCount; ++i)
            blockSet(i, _mm256_or_si256(blockAt(i), e().blockAt(i)));
        trim();
        return *this;
    }

    template <typename E>
    Bitset& operator^=(const Expression<E>& e) {
        for (size_t i = 0; i != kBlockCount; ++i)
            blockSet(i, _mm256_xor_si256(blockAt(i), e().blockAt(i)));
        trim();
        return *this;
    }

    template <typename E>
    Bitset& operator-=(const Expression<E>& e) {
        for (size_t i = 0; i != kBlockCount; ++i)
            blockSet(i, _mm256_andnot_si256(e().blockAt(i), blockAt(i)));
        trim();
        return *this;
    }

    friend bool operator==(const Bitset& A, const Bitset& B) {
        for (size_t i = 0; i != kBlockCount; ++i) {
            Block bxa = _mm256_xor_si256(B.blockAt(i), A.blockAt(i));
            if (!_mm256_testz_si256(bxa, bxa)) return false;
        }
        return true;
    }

    friend bool operator!=(const Bitset& A, const Bitset& B) {
        for (size_t i = 0; i != kBlockCount; ++i) {
            Block bxa = _mm256_xor_si256(B.blockAt(i), A.blockAt(i));
            if (!_mm256_testz_si256(bxa, bxa)) return true;
        }
        return false;
    }

    friend bool operator<=(const Bitset& A, const Bitset& B) {
        for (size_t i = 0; i != kBlockCount; ++i)
            if (!_mm256_testc_si256(B.blockAt(i), A.blockAt(i)))
                return false;
        return true;
    }

    friend bool operator>=(const Bitset& A, const Bitset& B) {
        for (size_t i = 0; i != kBlockCount; ++i)
            if (!_mm256_testc_si256(A.blockAt(i), B.blockAt(i)))
                return false;
        return true;
    }

    friend bool operator<(const Bitset& A, const Bitset& B) {
        bool different = false;
        for (size_t i = 0; i != kBlockCount; ++i) {
            Block a = A.blockAt(i), b = B.blockAt(i);
            if (!_mm256_testc_si256(b, a)) return false;
            different |= !_mm256_testc_si256(a, b);
        }
        return different;
    }

    friend bool operator>(const Bitset& A, const Bitset& B) {
        bool different = false;
        for (size_t i = 0; i != kBlockCount; ++i) {
            Block a = A.blockAt(i), b = B.blockAt(i);
            if (!_mm256_testc_si256(a, b)) return false;
            different |= !_mm256_testc_si256(b, a);
        }
        return different;
    }

    size_t findFirstSet(size_t position) const {
        if (position >= kSize) return size_t(-1);
        size_t wordIndex = position / kWordSize, bitInWord = position % kWordSize;
        size_t blockIndex = wordIndex / kBlockSize, wordInBlock = wordIndex % kBlockSize;
        if (Word word = mData[wordIndex] & ~((1ULL << bitInWord) - 1)) {
            size_t result = wordIndex * kWordSize + size_t(__builtin_ctzll(word));
            return result < kSize ? result : size_t(-1);
        }
        while (++wordInBlock != kBlockSize) {
            size_t current = blockIndex * kBlockSize + wordInBlock;
            if (current >= kWordCount) break;
            if (Word word = mData[current]) {
                size_t result = current * kWordSize + __builtin_ctzll(word);
                return result < kSize ? result : size_t(-1);
            }
        }
        while (++blockIndex != kBlockCount) {
            Block mask = _mm256_set1_epi64x(-1), block = blockAt(blockIndex);
            if (_mm256_testz_si256(block, mask)) continue;
            for (size_t i = 0; i != kBlockSize; ++i) {
                size_t current = blockIndex * kBlockSize + i;
                if (current >= kWordCount) break;
                if (Word word = mData[current]) {
                    size_t result = current * kWordSize + __builtin_ctzll(word);
                    return result < kSize ? result : size_t(-1);
                }
            }
        }
        return size_t(-1);
    }

    size_t findFirstUnset(size_t position) const {
        if (position >= kSize) return size_t(-1);
        size_t wordIndex = position / kWordSize, bitInWord = position % kWordSize;
        size_t blockIndex = wordIndex / kBlockSize, wordInBlock = wordIndex % kBlockSize;
        if (Word word = ~mData[wordIndex] & ~((1ULL << bitInWord) - 1)) {
            size_t result = wordIndex * kWordSize + size_t(__builtin_ctzll(word));
            return result < kSize ? result : size_t(-1);
        }
        while (++wordInBlock != kBlockSize) {
            size_t current = blockIndex * kBlockSize + wordInBlock;
            if (current >= kWordCount) break;
            if (Word word = ~mData[current]) {
                size_t result = current * kWordSize + __builtin_ctzll(word);
                return result < kSize ? result : size_t(-1);
            }
        }
        while (++blockIndex != kBlockCount) {
            Block mask = _mm256_set1_epi64x(-1), block = blockAt(blockIndex);
            if (_mm256_testc_si256(block, mask)) continue;
            for (size_t i = 0; i != kBlockSize; ++i) {
                size_t current = blockIndex * kBlockSize + i;
                if (current >= kWordCount) break;
                if (Word word = ~mData[current]) {
                    size_t result = current * kWordSize + __builtin_ctzll(word);
                    return result < kSize ? result : size_t(-1);
                }
            }
        }
        return size_t(-1);
    }

    void set(size_t position, size_t length) {
        if (position + length > kSize || !length) return;
        size_t wordIndex = position / kWordSize, bitInWord = position % kWordSize;
        if (bitInWord) {
            size_t headLength = kWordSize - bitInWord < length ? kWordSize - bitInWord : length;
            mData[wordIndex++] |= ((1ULL << headLength) - 1) << bitInWord;
            length -= headLength;
        }
        for (const Block value = _mm256_set1_epi64x(-1); length >= sizeof(Block) * CHAR_BIT; ) {
            _mm256_storeu_si256((Block*)&mData[wordIndex], value);
            length -= sizeof(Block) * CHAR_BIT;
            wordIndex += kBlockSize;
        }
        while (length >= kWordSize) {
            mData[wordIndex++] = -1ULL;
            length -= kWordSize;
        }
        mData[wordIndex] |= (1ULL << length) - 1;
    }

    void unset(size_t position, size_t length) {
        if (position + length > kSize || !length) return;
        size_t wordIndex = position / kWordSize, bitInWord = position % kWordSize;
        if (bitInWord) {
            size_t headLength = kWordSize - bitInWord < length ? kWordSize - bitInWord : length;
            mData[wordIndex++] &= ~(((1ULL << headLength) - 1) << bitInWord);
            length -= headLength;
        }
        for (const Block value = _mm256_setzero_si256(); length >= sizeof(Block) * CHAR_BIT; ) {
            _mm256_storeu_si256((Block*)&mData[wordIndex], value);
            length -= sizeof(Block) * CHAR_BIT;
            wordIndex += kBlockSize;
        }
        while (length >= kWordSize) {
            mData[wordIndex++] = 0ULL;
            length -= kWordSize;
        }
        mData[wordIndex] &= ~((1ULL << length) - 1);
    }

    void flip(size_t position, size_t length) {
        if (position + length > kSize || !length) return;
        size_t wordIndex = position / kWordSize, bitInWord = position % kWordSize;
        if (bitInWord) {
            size_t headLength = kWordSize - bitInWord < length ? kWordSize - bitInWord : length;
            mData[wordIndex++] ^= ((1ULL << headLength) - 1) << bitInWord;
            length -= headLength;
        }
        for (const Block value = _mm256_set1_epi64x(-1); length >= sizeof(Block) * CHAR_BIT; ) {
            _mm256_storeu_si256((Block*)&mData[wordIndex], _mm256_xor_si256(_mm256_loadu_si256((const Block*)&mData[wordIndex]), value));
            length -= sizeof(Block) * CHAR_BIT;
            wordIndex += kBlockSize;
        }
        while (length >= kWordSize) {
            mData[wordIndex++] ^= -1ULL;
            length -= kWordSize;
        }
        mData[wordIndex] ^= (1ULL << length) - 1;
    }

    Bitset& operator<<=(size_t step) {
        if (step >= kSize) return unset(), *this;
        size_t wordShift = step / kWordSize, bitShift = step % kWordSize;
        if (!bitShift) {
            memmove(mData + wordShift, mData, (kWordCount - wordShift) * sizeof(Word));
            memset(mData, 0, wordShift * sizeof(Word));
        } else {
            size_t remaining = kWordCount - wordShift - 1;
            Word *destination = mData + kWordCount, *source = destination - wordShift;
            __m128i lShift = _mm_cvtsi32_si128(int(bitShift)), rShift = _mm_cvtsi32_si128(int(kWordSize - bitShift));
            while (remaining >= kBlockSize) {
                destination -= kBlockSize, source -= kBlockSize;
                Block low = _mm256_srl_epi64(_mm256_loadu_si256((const Block*)&source[-1]), rShift);
                Block high = _mm256_sll_epi64(_mm256_loadu_si256((const Block*)&source[0]), lShift);
                _mm256_storeu_si256((Block*)destination, _mm256_or_si256(low, high));
                remaining -= kBlockSize;
            }
            while (remaining) {
                --destination, --source;
                *destination = (source[0] << bitShift) | (source[-1] >> (kWordSize - bitShift));
                --remaining;
            }
            *--destination = *--source << bitShift;
            memset(mData, 0, wordShift * sizeof(Word));
        }
        trim();
        return *this;
    }

    Bitset& operator>>=(size_t step) {
        if (step >= kSize) return unset(), *this;
        size_t wordShift = step / kWordSize, bitShift = step % kWordSize;
        if (!bitShift) {
            memmove(mData, mData + wordShift, (kWordCount - wordShift) * sizeof(Word));
            memset(mData + kWordCount - wordShift, 0, wordShift * sizeof(Word));
        } else {
            size_t remaining = kWordCount - wordShift - 1;
            Word *destination = mData, *source = mData + wordShift;
            __m128i rShift = _mm_cvtsi32_si128(int(bitShift)), lShift = _mm_cvtsi32_si128(int(kWordSize - bitShift));
            while (remaining >= kBlockSize) {
                Block low = _mm256_srl_epi64(_mm256_loadu_si256((const Block*)&source[0]), rShift);
                Block high = _mm256_sll_epi64(_mm256_loadu_si256((const Block*)&source[1]), lShift);
                _mm256_storeu_si256((Block*)destination, _mm256_or_si256(low, high));
                destination += kBlockSize, source += kBlockSize;
                remaining -= kBlockSize;
            }
            while (remaining) {
                *destination = (source[0] >> bitShift) | (source[1] << (kWordSize - bitShift));
                ++destination, ++source;
                --remaining;
            }
            *destination = *source >> bitShift;
            memset(mData + kWordCount - wordShift, 0, wordShift * sizeof(Word));
        }
        trim();
        return *this;
    }

    Bitset operator<<(size_t step) const {
        Bitset result;
        if (step >= kSize) return result;
        size_t wordShift = step / kWordSize, bitShift = step % kWordSize;
        if (!bitShift) {
            memcpy(result.mData + wordShift, mData, (kWordCount - wordShift) * sizeof(Word));
        } else {
            size_t remaining = kWordCount - wordShift - 1;
            Word* destination = result.mData + kWordCount;
            const Word* source = mData + kWordCount - wordShift;
            __m128i lShift = _mm_cvtsi32_si128(int(bitShift)), rShift = _mm_cvtsi32_si128(int(kWordSize - bitShift));
            while (remaining >= kBlockSize) {
                destination -= kBlockSize, source -= kBlockSize;
                Block low = _mm256_srl_epi64(_mm256_loadu_si256((const Block*)&source[-1]), rShift);
                Block high = _mm256_sll_epi64(_mm256_loadu_si256((const Block*)&source[0]), lShift);
                _mm256_storeu_si256((Block*)destination, _mm256_or_si256(low, high));
                remaining -= kBlockSize;
            }
            while (remaining) {
                --destination, --source;
                *destination = (source[0] << bitShift) | (source[-1] >> (kWordSize - bitShift));
                --remaining;
            }
            *--destination = *--source << bitShift;
        }
        result.trim();
        return result;
    }

    Bitset operator>>(size_t step) const {
        Bitset result;
        if (step >= kSize) return result;
        size_t wordShift = step / kWordSize, bitShift = step % kWordSize;
        if (!bitShift) {
            memcpy(result.mData, mData + wordShift, (kWordCount - wordShift) * sizeof(Word));
        } else {
            size_t remaining = kWordCount - wordShift - 1;
            Word* destination = result.mData;
            const Word* source = mData + wordShift;
            __m128i rShift = _mm_cvtsi32_si128(int(bitShift)), lShift = _mm_cvtsi32_si128(int(kWordSize - bitShift));
            while (remaining >= kBlockSize) {
                Block low = _mm256_srl_epi64(_mm256_loadu_si256((const Block*)&source[0]), rShift);
                Block high = _mm256_sll_epi64(_mm256_loadu_si256((const Block*)&source[1]), lShift);
                _mm256_storeu_si256((Block*)destination, _mm256_or_si256(low, high));
                destination += kBlockSize, source += kBlockSize;
                remaining -= kBlockSize;
            }
            while (remaining) {
                *destination = (source[0] >> bitShift) | (source[1] << (kWordSize - bitShift));
                ++destination, ++source;
                --remaining;
            }
            *destination = *source >> bitShift;
        }
        result.trim();
        return result;
    }
};

#endif

:::

性能测试

为了展示我们的向量化 bitset 与一般的标量 bitset 的性能差距,我们进行了一组性能测试。细节如下:

操作类型 标量用时 向量用时 提升倍率 其它说明
逻辑运算 112.44 61.65 1.82 按位与以及一次赋值
关系查询 81.47 38.17 2.13 子集关系查询,构造数据使得前 \frac 78n 的数据完全相同
区间修改 17.88 17.44 1.02 区间赋值为 1,区间长度为 \frac 78n
状态查询 81.47 38.17 2.13 是否全 1,构造数据使得前 \frac 78n 的数据均为 1
位置查找 31.25 18.44 1.69 查找下一个 1 的位置,构造数据使得前 \frac 78n 的数据均为 0
左右移位 147.91 91.78 1.61 左移以及一次赋值
个数统计 53.07 37.90 1.40 随机数据
嵌套运算 507.10 108.58 4.67 嵌套四次按位与以及一次赋值

可以看到,即使在 -Ofast -march=native 的加持下,编译器对标量 for 的自动向量化依旧比不上我们手写的 AVX2 Intrinsics。逻辑运算的提升达到了 82\%,算得上极为显著的优化,子集关系与状态查询提升甚至超过了 100\%!区间修改由于大块连续内存写入时瓶颈转移到了内存带宽,向量化带来的优化不大。

而对于位置查找与左右移位这两个非对齐操作,向量化也带来了明显的提升,分别为 70\%60\%。值得一提的是个数统计操作,在 -std=c++20 -march=native 编译参数下 std::popcount 会被识别并编译为极快的 POPCNT 硬件指令,尽管如此我们的并行查表法依然凭借每次处理 256 位的超宽数据流,硬生生从硬件指令手里抢下了 40\% 的性能提升。

最为恐怖的是嵌套运算,足足达到了 4.67 的提升倍率!这就是表达式模板为我带来的自信!