全面超越标准库 bitset 手把手教程
masonxiong · · 科技·工程
省流:实现了速度约是
std::bitset两倍的向量化 bitset,代码在参考实现部分。
说在前面
在算法竞赛中,bitset 一直是常数优化的利器。无论是图论中的传递闭包,状态压缩动态规划中的可行性转移,还是字符串匹配中的 Shift-And 算法,使用 bitset 都能通过将
大部分人追求代码简洁,选择了标准库提供的 std::bitset。然而,随着现代 CPU 架构的演进和竞赛题目时限的日益内卷,标准库的 std::bitset 开始暴露出两个令人头疼的性能瓶颈:
-
SIMD 向量化利用率低下:现代 x86 CPU 普遍支持 AVX2 指令集,能够在一个时钟周期内处理
256 位的数据。但受限于编译器的自动向量化能力和标准库的保守实现,std::bitset的底层往往只能依赖标量指令(逐个64 位处理,甚至部分实现是32 位的),白白浪费了 CPU 庞大的并行吞吐能力。 -
嵌套运算的临时对象灾难:在处理如
A = (B & C) ^ (~D | E)这样的复杂逻辑表达式时,C++ 的运算符重载机制会为每一步二元运算生成一个完整的临时对象,并伴随着多次对内存的完整遍历。这不仅导致了极高的内存带宽开销,还严重破坏了 CPU 的缓存命中率。
为了打破这些桎梏,获得更高的效率,我们不能再依赖标准库。本文将带你从零开始,手写一个专为算法竞赛打造的极速静态 bitset。我们将直接介入底层,利用 AVX2 指令集实现真正的
接下来是教程。
记号约定
为了方便描述,我们定义:
- 位:bitset 存储的基本单位,对应一个二进制位。
- 字:位的压缩产物,对应一个
64 位无符号整数。 - 块:字的压缩产物,对应一个
256 位整数向量。
基础布局
要实现一个极致性能的 bitset,第一步是设计其内存布局。与 std::bitset 类似,我们使用 uint64_t(字)数组来存储位信息。但为了充分利用 AVX2 指令集,我们需要引入 __m256i(块)。
template <size_t kSize>
class Bitset {
protected:
typedef uint64_t Word;
typedef __m256i Block;
static const size_t kWordSize = sizeof(Word) * CHAR_BIT;
static const size_t kBlockSize = sizeof(Block) / sizeof(Word);
static const size_t kWordCount = (kSize + kWordSize - 1) / kWordSize;
static const size_t kBlockCount = (kWordCount + kBlockSize - 1) / kBlockSize;
alignas(Block) Word mData[kBlockCount * kBlockSize]{};
void trim() {
if (kSize % kWordSize) mData[kWordCount - 1] &= (1ULL << (kSize % kWordSize)) - 1;
if (kBlockCount * kBlockSize > kWordCount) memset(mData + kWordCount, 0, (kBlockCount * kBlockSize - kWordCount) * sizeof(Word));
}
void normalize() {
if (kSize % kWordSize) mData[kWordCount - 1] &= (1ULL << (kSize % kWordSize)) - 1;
if (kBlockCount * kBlockSize > kWordCount) memset(mData + kWordCount, 0, (kBlockCount * kBlockSize - kWordCount) * sizeof(Word));
}
inline Word wordAt(size_t position) const {
return mData[position];
}
inline Block blockAt(size_t position) const {
return _mm256_load_si256((const Block*)&mData[position * kBlockSize]);
}
inline void wordSet(size_t position, Word value) {
mData[position] = value;
}
inline void blockSet(size_t position, Block value) {
_mm256_store_si256((Block*)&mData[position * kBlockSize], value);
}
};
这里有几个关键的设计细节:
-
基本单位:
Word是64 位的标量单位,而Block是256 位的向量单位。我们预先计算好所需的Word数量和Block数量。注意,数组的实际分配大小是kBlockCount * kBlockSize,这是在末尾补满一个块,保证我们在进行 SIMD 批量操作时,永远不会发生内存越界,免去一些边界讨论。 -
内存对齐:这是向量化编程中极其重要的一环。我们使用
alignas(Block)这一 C++11 关键字强制mData数组在内存中按照Block的对齐要求——32 字节——对齐。有了这个保证,我们在后续代码中就可以安全地使用_mm256_load_si256和_mm256_store_si256这类要求内存对齐的指令。相比于非对齐的loadu/storeu,对齐的内存访问在某些微架构上能带来可观的性能提升。 -
边界清理:由于我们在末尾分配了一些闲置位,且某些批量操作会把超出
kSize范围的无效位也置为1 ,这会严重干扰后续的操作。因此,我们需要一个trim() / normalize()函数,在每次修改操作完成后,将超出kSize的无效位强制清零。具体来说,首先用位掩码清理最后一个有效Word中多余的高位;然后用memset将后面纯粹作为填充的Word全部清零。 -
字块访问:定义了几个函数来方便我们访问或修改指定的字或块。其中
_mm256_load_si256以及_mm256_store_si256是 AVX2 指令集中的指令,位于<immintrin.h>头文件中,可以查询 Intel Intrinsics Guide 获取详细信息。本文中所有用到的 AVX2 指令及其作用见下,其中等价代码仅作辅助理解使用:
-
Block _mm256_setzero_si256()- 指令作用:返回一个空
256 位向量。 - 等价代码:
Block result; memset(&result, 0, sizeof(Block)); return result;
- 指令作用:返回一个空
-
Block _mm256_load_si256(const Block* source)- 指令作用:加载
source指向的一个256 位向量。source的地址必须按照Block的对齐要求对齐。 - 等价代码:
return *source;
- 指令作用:加载
-
Block _mm256_loadu_si256(const Block* source)- 指令作用:加载
source指向的一个256 位向量。 - 等价代码:
return *source;
- 指令作用:加载
-
void _mm256_store_si256(Block* destination, Block data)- 指令作用:在
destination储存一个256 位向量。destination的地址必须按照Block的对齐要求对齐。 - 等价代码:
*destination = data;
- 指令作用:在
-
void _mm256_storeu_si256(Block* destination, Block data)- 指令作用:在
destination储存一个256 位向量。 - 等价代码:
*destination = data;
- 指令作用:在
-
int _mm256_testz_si256(Block A, Block B)- 指令作用:测试
A & B的结果是否全0 。 - 等价代码:
return (A & B) == 0;
- 指令作用:测试
-
int _mm256_testc_si256(Block A, Block B)- 指令作用:测试
B中为1 的位在A中是否为1 。 - 等价代码:
return (A & B) == B;
- 指令作用:测试
-
Block _mm256_set1_epi8(char value)- 指令作用:返回一个
256 位向量,其中每个8 位整数均被赋值为value。 - 等价代码:
Block result; std::fill_n(&result, sizeof(Block) / sizeof(char), value); return result;
- 指令作用:返回一个
-
Block _mm256_set1_epi64x(long long value)- 指令作用:返回一个
256 位向量,其中每个64 位整数均被赋值为value。 - 等价代码:
Block result; std::fill_n(&result, sizeof(Block) / sizeof(long long), value); return result;
- 指令作用:返回一个
-
Block _mm256_setr_epi8(char A31, char A30, ..., char A0)- 指令作用:返回一个
256 位向量,用参数列表逆序填充每个8 位整数。 - 等价代码:
Block result; std::vector<char> A{A31, A30, ..., A0}; memcpy(&result, A.data(), sizeof(Block)); return result;
- 指令作用:返回一个
-
Block _mm256_setr_epi64x(long long A3, long long A2, ..., long long A0)- 指令作用:返回一个
256 位向量,用参数列表逆序填充每个64 位整数。 - 等价代码:
Block result; std::vector<long long> A{A3, A2, ..., A0}; memcpy(&result, A.data(), sizeof(Block)); return result;
- 指令作用:返回一个
-
Block _mm256_add_epi8(Block A, Block B)- 指令作用:将两个
256 位向量中的每个8 位整数对位相加并返回结果。 - 等价代码:
Block result; for (size_t i = 0; i < sizeof(Block) / sizeof(char); ++i) ((char*)&result)[i] = ((char*)&A)[i] + ((char*)&B)[i]; return result;
- 指令作用:将两个
-
Block _mm256_add_epi64(Block A, Block B)- 指令作用:将两个
256 位向量中的每个64 位整数对位相加并返回结果。 - 等价代码:
Block result; for (size_t i = 0; i < sizeof(Block) / sizeof(long long); ++i) ((long long*)&result)[i] = ((long long*)&A)[i] + ((long long*)&B)[i]; return result;
- 指令作用:将两个
-
Block _mm256_and_si256(Block A, Block B)- 指令作用:将两个
256 位向量视作两个256 位整数,返回它们的按位与。 - 等价代码:
return A & B;
- 指令作用:将两个
-
Block _mm256_or_si256(Block A, Block B)- 指令作用:将两个
256 位向量视作两个256 位整数,返回它们的按位或。 - 等价代码:
return A | B;
- 指令作用:将两个
-
Block _mm256_xor_si256(Block A, Block B)- 指令作用:将两个
256 位向量视作两个256 位整数,返回它们的按位异或。 - 等价代码:
return A ^ B;
- 指令作用:将两个
-
Block _mm256_andnot_si256(Block A, Block B)- 指令作用:将两个
256 位向量视作两个256 位整数,返回它们的按位与非。 - 等价代码:
return ~A & B;
- 指令作用:将两个
-
Block _mm256_srli_epi16(Block A, int shift)- 指令作用:将
256 位向量中的每个16 位整数右移shift位。shift应当是常数。 - 等价代码:
Block result; for (size_t i = 0; i < sizeof(Block) / sizeof(short); ++i) ((short*)&result)[i] = ((short*)&A)[i] >> shift; return result;
- 指令作用:将
-
Block _mm256_slli_epi16(Block A, int shift)- 指令作用:将
256 位向量中的每个16 位整数左移shift位。shift应当是常数。 - 等价代码:
Block result; for (size_t i = 0; i < sizeof(Block) / sizeof(short); ++i) ((short*)&result)[i] = ((short*)&A)[i] << shift; return result;
- 指令作用:将
-
Block _mm256_srl_epi64(Block A, __m128i shift)- 指令作用:将
256 位向量中的每个64 位整数右移shift位。仅使用shift中的低32 位,可以用_mm256_cvtsi32_si128(int n)快速生成这样的移位用128 位向量。 - 等价代码:
Block result; for (size_t i = 0; i < sizeof(Block) / sizeof(long long); ++i) ((long long*)&result)[i] = ((long long*)&A)[i] >> int(shift[0]); return result;
- 指令作用:将
-
Block _mm256_sll_epi64(Block A, __m128i shift)- 指令作用:将
256 位向量中的每个64 位整数左移shift位。仅使用shift中的低32 位,可以用_mm256_cvtsi32_si128(int n)快速生成这样的移位用128 位向量。 - 等价代码:
Block result; for (size_t i = 0; i < sizeof(Block) / sizeof(long long); ++i) ((long long*)&result)[i] = ((long long*)&A)[i] << int(shift[0]); return result;
- 指令作用:将
-
Block _mm256_sad_epu8(Block A, Block B)- 指令作用:将两个
256 位向量中的每个8 位整数对位相减并取绝对值,然后将每8 个8 位的差的绝对值求和,存储到这8 个8 位对应的64 位整数中。 - 等价代码:
Block result; for (size_t i = 0; i < sizeof(Block) / sizeof(long long); ++i) { ((long long*)&result)[i] = 0; for (size_t j = 0; j < sizeof(long long) / sizeof(char); ++j) { ((long long*)&result)[i] += abs( ((char*)&A)[i * sizeof(long long) + j] - ((char*)&B)[i * sizeof(long long) + j] ); } } return result;
- 指令作用:将两个
-
Block _mm256_shuffle_epi8(Block A, Block B)- 指令作用:将
A视为包含32 个8 位整数的查找表,将B视为索引向量。对于B中的每一个8 位整数,如果它的最高位为1 ,则结果中对应的8 位整数置为0 ;否则,取其低4 位作为索引,从A的对应128 位半区中取出对应的字节放入结果中。B的低16 字节只能索引A的低16 字节,高 16 字节只能索引A的高16 字节。 - 等价代码:
Block result; for (size_t i = 0; i < sizeof(Block) / sizeof(char); ++i) { if (((char*)&B)[i] & 0x80) { ((char*)&result)[i] = 0; } else { size_t lane = sizeof(Block) / sizeof(char) / 2; size_t index = ((char*)&B)[i] & 0x0F, offset = i / lane * lane; ((char*)&result)[i] = ((char*)&A)[index + offset]; } } return result;
- 指令作用:将
简单操作向量化
有了上面这些指令,我们就可以向量化一些简单的操作了。
逻辑运算
即按位与、或、异或、差。这是最容易向量化的部分了,两个 bitset 同时遍历每个块,然后通过 _mm256_and_si256、_mm256_or_si256、_mm256_xor_si256、_mm256_andnot_si256 计算结果即可。注意集合差对应的是反的与非。
friend Bitset operator&(const Bitset& A, const Bitset& B) {
Bitset result;
for (size_t i = 0; i != kBlockCount; ++i)
result.blockSet(i, _mm256_and_si256(B.blockAt(i), A.blockAt(i)));
return result;
}
friend Bitset operator|(const Bitset& A, const Bitset& B) {
Bitset result;
for (size_t i = 0; i != kBlockCount; ++i)
result.blockSet(i, _mm256_or_si256(B.blockAt(i), A.blockAt(i)));
return result;
}
friend Bitset operator^(const Bitset& A, const Bitset& B) {
Bitset result;
for (size_t i = 0; i != kBlockCount; ++i)
result.blockSet(i, _mm256_xor_si256(B.blockAt(i), A.blockAt(i)));
return result;
}
friend Bitset operator-(const Bitset& A, const Bitset& B) {
Bitset result;
for (size_t i = 0; i != kBlockCount; ++i)
result.blockSet(i, _mm256_andnot_si256(B.blockAt(i), A.blockAt(i)));
return result;
}
关系查询
判断两个 bitset 的关系,注意这里的 >=、>、<=、< 分别指的是包含、真包含、子集、真子集,也就是
相等判断可以使用 memcmp 实现,但是由于我们保证内存对齐所以手写向量化可能效率更高,具体来说我们求两个 bitset 块对位异或,然后用 _mm256_testz_si256 判断其是否全 _mm256_testc_si256,真包含在包含的基础上判断不等即可。
friend bool operator==(const Bitset& A, const Bitset& B) {
for (size_t i = 0; i != kBlockCount; ++i) {
Block bxa = _mm256_xor_si256(B.blockAt(i), A.blockAt(i));
if (!_mm256_testz_si256(bxa, bxa)) return false;
}
return true;
}
friend bool operator!=(const Bitset& A, const Bitset& B) {
for (size_t i = 0; i != kBlockCount; ++i) {
Block bxa = _mm256_xor_si256(B.blockAt(i), A.blockAt(i));
if (!_mm256_testz_si256(bxa, bxa)) return true;
}
return false;
}
friend bool operator<=(const Bitset& A, const Bitset& B) {
for (size_t i = 0; i != kBlockCount; ++i)
if (!_mm256_testc_si256(B.blockAt(i), A.blockAt(i)))
return false;
return true;
}
friend bool operator>=(const Bitset& A, const Bitset& B) {
for (size_t i = 0; i != kBlockCount; ++i)
if (!_mm256_testc_si256(A.blockAt(i), B.blockAt(i)))
return false;
return true;
}
friend bool operator<(const Bitset& A, const Bitset& B) {
bool different = false;
for (size_t i = 0; i != kBlockCount; ++i) {
Block a = A.blockAt(i), b = B.blockAt(i);
if (!_mm256_testc_si256(b, a)) return false;
different |= !_mm256_testc_si256(a, b);
}
return different;
}
friend bool operator>(const Bitset& A, const Bitset& B) {
bool different = false;
for (size_t i = 0; i != kBlockCount; ++i) {
Block a = A.blockAt(i), b = B.blockAt(i);
if (!_mm256_testc_si256(a, b)) return false;
different |= !_mm256_testc_si256(b, a);
}
return different;
}
区间修改
区间赋值,区间翻转。如果逐位操作,显然太慢了。标准的向量化区间操作通常采用“头部——躯干——尾部”的三段式处理策略。
首先是头部,position 可能并不对齐到 _mm256_storeu_si256 每次写入 _mm256_xor_si256 每次翻转
最后剩下的不足
void set(size_t position, size_t length) {
if (position + length > kSize || !length) return;
size_t wordIndex = position / kWordSize, bitInWord = position % kWordSize;
if (bitInWord) {
size_t headLength = kWordSize - bitInWord < length ? kWordSize - bitInWord : length;
mData[wordIndex++] |= ((1ULL << headLength) - 1) << bitInWord;
length -= headLength;
}
for (const Block value = _mm256_set1_epi64x(-1); length >= sizeof(Block) * CHAR_BIT; ) {
_mm256_storeu_si256((Block*)&mData[wordIndex], value);
length -= sizeof(Block) * CHAR_BIT;
wordIndex += kBlockSize;
}
while (length >= kWordSize) {
mData[wordIndex++] = -1ULL;
length -= kWordSize;
}
mData[wordIndex] |= (1ULL << length) - 1;
}
void unset(size_t position, size_t length) {
if (position + length > kSize || !length) return;
size_t wordIndex = position / kWordSize, bitInWord = position % kWordSize;
if (bitInWord) {
size_t headLength = kWordSize - bitInWord < length ? kWordSize - bitInWord : length;
mData[wordIndex++] &= ~(((1ULL << headLength) - 1) << bitInWord);
length -= headLength;
}
for (const Block value = _mm256_setzero_si256(); length >= sizeof(Block) * CHAR_BIT; ) {
_mm256_storeu_si256((Block*)&mData[wordIndex], value);
length -= sizeof(Block) * CHAR_BIT;
wordIndex += kBlockSize;
}
while (length >= kWordSize) {
mData[wordIndex++] = 0ULL;
length -= kWordSize;
}
mData[wordIndex] &= ~((1ULL << length) - 1);
}
void flip(size_t position, size_t length) {
if (position + length > kSize || !length) return;
size_t wordIndex = position / kWordSize, bitInWord = position % kWordSize;
if (bitInWord) {
size_t headLength = kWordSize - bitInWord < length ? kWordSize - bitInWord : length;
mData[wordIndex++] ^= ((1ULL << headLength) - 1) << bitInWord;
length -= headLength;
}
for (const Block value = _mm256_set1_epi64x(-1); length >= sizeof(Block) * CHAR_BIT; ) {
_mm256_storeu_si256((Block*)&mData[wordIndex], _mm256_xor_si256(_mm256_loadu_si256((const Block*)&mData[wordIndex]), value));
length -= sizeof(Block) * CHAR_BIT;
wordIndex += kBlockSize;
}
while (length >= kWordSize) {
mData[wordIndex++] ^= -1ULL;
length -= kWordSize;
}
mData[wordIndex] ^= (1ULL << length) - 1;
}
状态查询
查询 bitset 是否全 _mm256_testz_si256(A, B) 可以判断 (A & B) == 0,如果我们将 B 置为全 A == 0,可以利用其实现 none 以及 any。至于 all 的判断,可以用 _mm256_testc_si256(A, B) 判断 (~A & B) == 0,如果我们将 B 置为全 A 是全
但是请一定要注意特判末尾填充位!如果 kSize 不是 trim() 的作用下永远是 _mm256_testc_si256 来判断 all(),由于填充位是
bool none() const {
const size_t fullBlocks = kSize / (sizeof(Block) * CHAR_BIT);
for (size_t i = 0; i != fullBlocks; ++i) {
Block mask = _mm256_set1_epi8(-1), block = blockAt(i);
if (!_mm256_testz_si256(block, mask)) return false;
}
if (size_t remaining = kSize % (sizeof(Block) * CHAR_BIT)) {
Word maskArray[kBlockSize] = {};
for (size_t i = 0; i != kBlockSize; ++i) {
if (remaining >= kWordSize) {
maskArray[i] = ~0ULL;
remaining -= kWordSize;
} else {
maskArray[i] = (1ULL << remaining) - 1;
break;
}
}
Block mask = _mm256_loadu_si256((const Block*)maskArray), block = blockAt(fullBlocks);
if (!_mm256_testz_si256(block, mask)) return false;
}
return true;
}
bool any() const {
const size_t fullBlocks = kSize / (sizeof(Block) * CHAR_BIT);
for (size_t i = 0; i != fullBlocks; ++i) {
Block mask = _mm256_set1_epi8(-1), block = blockAt(i);
if (!_mm256_testz_si256(block, mask)) return true;
}
if (size_t remaining = kSize % (sizeof(Block) * CHAR_BIT)) {
Word maskArray[kBlockSize] = {};
for (size_t i = 0; i != kBlockSize; ++i) {
if (remaining >= kWordSize) {
maskArray[i] = ~0ULL;
remaining -= kWordSize;
} else {
maskArray[i] = (1ULL << remaining) - 1;
break;
}
}
Block mask = _mm256_loadu_si256((const Block*)maskArray), block = blockAt(fullBlocks);
if (!_mm256_testz_si256(block, mask)) return true;
}
return false;
}
bool all() const {
const size_t fullBlocks = kSize / (sizeof(Block) * CHAR_BIT);
for (size_t i = 0; i != fullBlocks; ++i) {
Block mask = _mm256_set1_epi8(-1), block = blockAt(i);
if (!_mm256_testc_si256(block, mask)) return false;
}
if (size_t remaining = kSize % (sizeof(Block) * CHAR_BIT)) {
Word maskArray[kBlockSize] = {};
for (size_t i = 0; i != kBlockSize; ++i) {
if (remaining >= kWordSize) {
maskArray[i] = ~0ULL;
remaining -= kWordSize;
} else {
maskArray[i] = (1ULL << remaining) - 1;
break;
}
}
Block mask = _mm256_loadu_si256((const Block*)maskArray), block = blockAt(fullBlocks);
if (!_mm256_testc_si256(block, mask)) return false;
}
return true;
}
位置查找
以查找下一个设置为 std::bitset 的 _Find_next 的实现方式,并加以向量化。
首先找到查找位置所在的字,通过掩码消掉该查找位置之前的所有位,然后检查其是否全 __builtin_ctzll 找到第一个设置为 none() 的实现方式使用 _mm256_testz_si256 快速判断当前块是否全 __builtin_ctzll 找到第一个设置为
查
size_t findFirstSet(size_t position) const {
if (position >= kSize) return size_t(-1);
size_t wordIndex = position / kWordSize, bitInWord = position % kWordSize;
size_t blockIndex = wordIndex / kBlockSize, wordInBlock = wordIndex % kBlockSize;
if (Word word = mData[wordIndex] & ~((1ULL << bitInWord) - 1)) {
size_t result = wordIndex * kWordSize + size_t(__builtin_ctzll(word));
return result < kSize ? result : size_t(-1);
}
while (++wordInBlock != kBlockSize) {
size_t current = blockIndex * kBlockSize + wordInBlock;
if (current >= kWordCount) break;
if (Word word = mData[current]) {
size_t result = current * kWordSize + __builtin_ctzll(word);
return result < kSize ? result : size_t(-1);
}
}
while (++blockIndex != kBlockCount) {
Block mask = _mm256_set1_epi64x(-1), block = blockAt(blockIndex);
if (_mm256_testz_si256(block, mask)) continue;
for (size_t i = 0; i != kBlockSize; ++i) {
size_t current = blockIndex * kBlockSize + i;
if (current >= kWordCount) break;
if (Word word = mData[current]) {
size_t result = current * kWordSize + __builtin_ctzll(word);
return result < kSize ? result : size_t(-1);
}
}
}
return size_t(-1);
}
size_t findFirstUnset(size_t position) const {
if (position >= kSize) return size_t(-1);
size_t wordIndex = position / kWordSize, bitInWord = position % kWordSize;
size_t blockIndex = wordIndex / kBlockSize, wordInBlock = wordIndex % kBlockSize;
if (Word word = ~mData[wordIndex] & ~((1ULL << bitInWord) - 1)) {
size_t result = wordIndex * kWordSize + size_t(__builtin_ctzll(word));
return result < kSize ? result : size_t(-1);
}
while (++wordInBlock != kBlockSize) {
size_t current = blockIndex * kBlockSize + wordInBlock;
if (current >= kWordCount) break;
if (Word word = ~mData[current]) {
size_t result = current * kWordSize + __builtin_ctzll(word);
return result < kSize ? result : size_t(-1);
}
}
while (++blockIndex != kBlockCount) {
Block mask = _mm256_set1_epi64x(-1), block = blockAt(blockIndex);
if (_mm256_testc_si256(block, mask)) continue;
for (size_t i = 0; i != kBlockSize; ++i) {
size_t current = blockIndex * kBlockSize + i;
if (current >= kWordCount) break;
if (Word word = ~mData[current]) {
size_t result = current * kWordSize + __builtin_ctzll(word);
return result < kSize ? result : size_t(-1);
}
}
}
return size_t(-1);
}
复杂操作向量化
部分操作向量化较为复杂,需要一定的思考。
左右移位
以左移为例。这是向量化 bitset 最难啃的骨头,为什么?因为 AVX2 没有提供跨 _mm256_slli_epi64 只能让 step 拆分为 wordShift = step / kWordSize 整字偏移量,以及 bitShift = step % kWordSize 位偏移量。然后分类讨论:
首先是 bitShift == 0 的情况,也就是整字偏移。这种情况下我们根本不需要位运算,直接使用 memmove 搬运内存即可。注意由于复制源区间和目标区间存在重叠,所以这里不能使用 memcpy。最后用 memset 把移动掉的位置清零即可。
当存在位偏移时,我们做的事情是:将当前字左移 bitShift 位,并提取相邻低地址字溢出的高 kWordSize - bitShift 位,将两者拼接,也就是或起来。为了向量化这个过程,我们使用非对齐加载 _mm256_loadu_si256,故意错开一个字的位置加载相邻的数据块,然后分别进行向量左移和右移,最后拼接。左右移使用 _mm256_sll_epi64 和 _mm256_srl_epi64 实现,通过 _mm_cvtsi32_si128 将偏移量 bitShift 转成要求的 __m128i 形式传进去。
需要注意的是,左移必须从高地址向低地址逆向遍历,右移必须从低向高遍历,这是为了防止源数据在被读取前就被覆盖。同时由于我们错开了一个字,所以我们需要使用非对齐加载。在处理完所有块之后,剩余不足整块的若干个字,我们回退到普通的标量循环进行拼接即可。最后别忘了调用 trim() 清理边界。
Bitset& operator<<=(size_t step) {
if (step >= kSize) return unset(), *this;
size_t wordShift = step / kWordSize, bitShift = step % kWordSize;
if (!bitShift) {
memmove(mData + wordShift, mData, (kWordCount - wordShift) * sizeof(Word));
memset(mData, 0, wordShift * sizeof(Word));
} else {
size_t remaining = kWordCount - wordShift - 1;
Word *destination = mData + kWordCount, *source = destination - wordShift;
__m128i lShift = _mm_cvtsi32_si128(int(bitShift)), rShift = _mm_cvtsi32_si128(int(kWordSize - bitShift));
while (remaining >= kBlockSize) {
destination -= kBlockSize, source -= kBlockSize;
Block low = _mm256_srl_epi64(_mm256_loadu_si256((const Block*)&source[-1]), rShift);
Block high = _mm256_sll_epi64(_mm256_loadu_si256((const Block*)&source[0]), lShift);
_mm256_storeu_si256((Block*)destination, _mm256_or_si256(low, high));
remaining -= kBlockSize;
}
while (remaining) {
--destination, --source;
*destination = (source[0] << bitShift) | (source[-1] >> (kWordSize - bitShift));
--remaining;
}
*--destination = *--source << bitShift;
memset(mData, 0, wordShift * sizeof(Word));
}
trim();
return *this;
}
Bitset& operator>>=(size_t step) {
if (step >= kSize) return unset(), *this;
size_t wordShift = step / kWordSize, bitShift = step % kWordSize;
if (!bitShift) {
memmove(mData, mData + wordShift, (kWordCount - wordShift) * sizeof(Word));
memset(mData + kWordCount - wordShift, 0, wordShift * sizeof(Word));
} else {
size_t remaining = kWordCount - wordShift - 1;
Word *destination = mData, *source = mData + wordShift;
__m128i rShift = _mm_cvtsi32_si128(int(bitShift)), lShift = _mm_cvtsi32_si128(int(kWordSize - bitShift));
while (remaining >= kBlockSize) {
Block low = _mm256_srl_epi64(_mm256_loadu_si256((const Block*)&source[0]), rShift);
Block high = _mm256_sll_epi64(_mm256_loadu_si256((const Block*)&source[1]), lShift);
_mm256_storeu_si256((Block*)destination, _mm256_or_si256(low, high));
destination += kBlockSize, source += kBlockSize;
remaining -= kBlockSize;
}
while (remaining) {
*destination = (source[0] >> bitShift) | (source[1] << (kWordSize - bitShift));
++destination, ++source;
--remaining;
}
*destination = *source >> bitShift;
memset(mData + kWordCount - wordShift, 0, wordShift * sizeof(Word));
}
trim();
return *this;
}
个数统计
标量 bitset 的 popcnt 方法是对每个字调用 __builtin_popcountll 并求和,可惜 AVX2 指令集中并没有现成的 popcnt 指令。类比常规 popcnt 的查表法,联想到 _mm256_shuffle_epi8 的查表功能,以此为突破口尝试实现。
首先 _mm256_shuffle_epi8 的查表是每个 table,通过掩码消去 block 中每个 _mm256_shuffle_epi8(table, block) 得到的结果 low 的每个 block 中的对应 _mm256_srli_epi16 将 block 右移 high 就存储了 block 中的对应
得到 low 和 high 之后,使用 _mm256_add_epi8 将两者对位相加,所得结果 sum 的每个 block 中对应的 _mm256_sad_epu8 实现的 sum 和一个全 _mm256_sad_epu8(sum, _mm256_setzero_si256()),所得结果的每个 sum 中对应
接着,我们开一个向量 result 作为累加器。每当我们得到一个块的 popcnt 结果向量后,使用 _mm256_add_epi64 将结果向量对位累加到 result 上。最后将 result 用 _mm256_storeu_si256 存到
size_t count() const {
const Block mask = _mm256_set1_epi8(0x0F);
const Block table = _mm256_setr_epi8(
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4
);
Block result = _mm256_setzero_si256();
for (size_t i = 0; i != kBlockCount; ++i) {
Block block = blockAt(i);
Block low = _mm256_shuffle_epi8(table, _mm256_and_si256(block, mask));
Block high = _mm256_shuffle_epi8(table, _mm256_and_si256(_mm256_srli_epi16(block, 4), mask));
result = _mm256_add_epi64(result, _mm256_sad_epu8(_mm256_add_epi8(low, high), _mm256_setzero_si256()));
}
Word parts[kBlockSize];
_mm256_storeu_si256((Block*)parts, result);
return size_t(parts[0] + parts[1] + parts[2] + parts[3]);
}
size_t popcnt() const {
const Block mask = _mm256_set1_epi8(0x0F);
const Block table = _mm256_setr_epi8(
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4
);
Block result = _mm256_setzero_si256();
for (size_t i = 0; i != kBlockCount; ++i) {
Block block = blockAt(i);
Block low = _mm256_shuffle_epi8(table, _mm256_and_si256(block, mask));
Block high = _mm256_shuffle_epi8(table, _mm256_and_si256(_mm256_srli_epi16(block, 4), mask));
result = _mm256_add_epi64(result, _mm256_sad_epu8(_mm256_add_epi8(low, high), _mm256_setzero_si256()));
}
Word parts[kBlockSize];
_mm256_storeu_si256((Block*)parts, result);
return size_t(parts[0] + parts[1] + parts[2] + parts[3]);
}
嵌套运算
以上实现的向量化 Bitset 在单一操作上的表现十分优秀,但是这个写法在类似 A = ~B & (C | D) 的嵌套表达式上会暴露出严重的性能缺陷:
- 临时对象泛滥:计算
~B会生成一个临时 bitset,计算C | D会生成第二个,最后做&又会生成第三个。 - 多次遍历内存:上述每一步都会触发一次对整个底层数组的完整遍历,导致极高的内存带宽开销和缓存丢失。
如果手写,我们可以通过循环融合,也就是手写 A[i] = ~B[i] & (C[i] | D[i]) 来避免这一问题,实现单趟遍历。但每次都手写显然违背了封装的初衷。有没有办法让编译器自动帮我们完成这种“循环融合”呢?答案就是 C++ 模板元编程的黑魔法——表达式模板。
表达式模板的核心思想是:重载运算符时不进行实际计算,而是返回一个轻量级的对象,用于在编译期记录表达式的抽象语法树。直到最终赋值给目标对象时,才触发真正的计算。其中语法树是一棵二叉树,其叶子节点代表一个 bitset,非叶子节点有一个二元运算表示“此节点代表的 bitset 为其左右儿子代表的 bitset 做此二元运算后的结果”。对一个表达式构建出语法树后,求出根节点代表的 bitset 即为此表达式的值。
首先我们定义一个基类 Expression,所有继承自此类型的类型均被视为表达式树上的节点。同时我们在基类提供一个向其后代转化的方法。这里利用 CRTP 即奇异递归模板模式来实现静态多态,避免虚函数带来的运行时开销。我们的 Bitset 类也会继承自 Expression<Bitset<kSize>>,使其成为语法树的叶子节点。接下来,我们为每一种运算定义一个节点结构体:
template <typename E>
struct Expression {
inline const E& operator()() const {
return static_cast<const E&>(*this);
}
};
template <typename E>
struct BitsetFlip : Expression<BitsetFlip<E>> {
const E& mE;
BitsetFlip(const E& e) : mE{e} {}
inline __m256i blockAt(size_t i) const {
return _mm256_xor_si256(mE.blockAt(i), _mm256_set1_epi64x(-1));
}
};
template <typename E>
struct BitsetNot : Expression<BitsetNot<E>> {
const E& mE;
BitsetNot(const E& e) : mE{e} {}
inline __m256i blockAt(size_t i) const {
return _mm256_xor_si256(mE.blockAt(i), _mm256_set1_epi64x(-1));
}
};
template <typename XE, typename YE>
struct BitsetAnd : Expression<BitsetAnd<XE, YE>> {
const XE& mXE;
const YE& mYE;
BitsetAnd(const XE& xE, const YE& yE) : mXE{xE}, mYE{yE} {}
inline __m256i blockAt(size_t i) const {
return _mm256_and_si256(mYE.blockAt(i), mXE.blockAt(i));
}
};
template <typename XE, typename YE>
struct BitsetOr : Expression<BitsetOr<XE, YE>> {
const XE& mXE;
const YE& mYE;
BitsetOr(const XE& xE, const YE& yE) : mXE{xE}, mYE{yE} {}
inline __m256i blockAt(size_t i) const {
return _mm256_or_si256(mYE.blockAt(i), mXE.blockAt(i));
}
};
template <typename XE, typename YE>
struct BitsetXor : Expression<BitsetXor<XE, YE>> {
const XE& mXE;
const YE& mYE;
BitsetXor(const XE& xE, const YE& yE) : mXE{xE}, mYE{yE} {}
inline __m256i blockAt(size_t i) const {
return _mm256_xor_si256(mYE.blockAt(i), mXE.blockAt(i));
}
};
template <typename XE, typename YE>
struct BitsetAnt : Expression<BitsetAnt<XE, YE>> {
const XE& mXE;
const YE& mYE;
BitsetAnt(const XE& xE, const YE& yE) : mXE{xE}, mYE{yE} {}
inline __m256i blockAt(size_t i) const {
return _mm256_andnot_si256(mYE.blockAt(i), mXE.blockAt(i));
}
};
注意这些运算节点结构体内部只保留了左右子节点的常引用。blockAt(i) 定义了实际的计算逻辑,但是此时并未执行。接下来我们重载全局的运算符,让它们返回语法树节点,而不是具体的 bitset:
template <typename E>
inline BitsetFlip<E> operator~(const Expression<E>& e) {
return BitsetFlip<E>(e());
}
template <typename E>
inline BitsetNot<E> operator!(const Expression<E>& e) {
return BitsetNot<E>(e());
}
template <typename XE, typename YE>
inline BitsetAnd<XE, YE> operator&(const Expression<XE>& xE, const Expression<YE>& yE) {
return BitsetAnd<XE, YE>(xE(), yE());
}
template <typename XE, typename YE>
inline BitsetOr<XE, YE> operator|(const Expression<XE>& xE, const Expression<YE>& yE) {
return BitsetOr<XE, YE>(xE(), yE());
}
template <typename XE, typename YE>
inline BitsetXor<XE, YE> operator^(const Expression<XE>& xE, const Expression<YE>& yE) {
return BitsetXor<XE, YE>(xE(), yE());
}
template <typename XE, typename YE>
inline BitsetAnt<XE, YE> operator-(const Expression<XE>& xE, const Expression<YE>& yE) {
return BitsetAnt<XE, YE>(xE(), yE());
}
最后,我们在 bitset 中提供接受 Expression 的方法。
Bitset(const Bitset&) = default;
Bitset(Bitset&&) = default;
template <typename E>
Bitset(const Expression<E>& e) {
for (size_t i = 0; i != kBlockCount; ++i)
blockSet(i, e().blockAt(i));
trim();
}
Bitset& operator=(const Bitset&) = default;
Bitset& operator=(Bitset&&) = default;
template <typename E>
Bitset& operator=(const Expression<E>& e) {
for (size_t i = 0; i != kBlockCount; ++i)
blockSet(i, e().blockAt(i));
trim();
return *this;
}
接下来我们以 A = ~B & (C | D) 为例讲解语法树的具体流程。
~B调用operator~,得到一个BitsetFlip对象。无整体赋值。(C | D)调用operator|,得到一个BitsetOr对象。无整体赋值。~B & (C | D)调用operator&,得到一个BitsetAnd对象。无整体赋值。A = ~B & (C | D)调用operator=:- 遍历每个块,调用
BitsetAnd的blockAt。 BitsetAnd的blockAt递归调用其左右子节点BitsetFlip和BitsetOr的blockAt。BitsetFlip的blockAt递归调用其子节点B的blockAt得到具体的值。BitsetOr的blockAt递归调用其左右子节点C和D的blockAt得到具体的值。
- 遍历每个块,调用
由于所有的 blockAt 均被标记为 inline,所以这些调用会被全部内联!最终生成的代码等价于:
for (size_t i = 0; i != kBlockCount; ++i) {
A.blockSet(i, _mm256_and_si256(
_mm256_xor_si256(B.blockAt(i), _mm256_set1_epi64x(-1)),
_mm256_or_si256(C.blockAt(i), D.blockAt(i))
));
}
这完美地实现了循环融合!没有临时对象,只有一次内存遍历,所有的计算被融合进了一条紧凑的向量化指令流中。这是一个非常优雅且高效的解决方案。但是请注意下面这段代码:
auto foo() {
Bitset<16> result;
return ~result;
}
int main() {
auto tmp = foo();
Bitset<16> bs = tmp;
}
程序会爆掉。怎么会这样?我们来分析一下。首先 ~result 的类型是 BitsetFlip,于是 foo 的返回值被 auto 推断为 BitsetFlip,所以 tmp 也被 auto 推断为 BitsetFlip 类型。但是,此时 tmp 保存的引用是 result 的,而 result 的生命周期已经结束了,tmp 保存的是悬垂引用,那么下面 bs = tmp 便成为了 UB。
那有人可能会想到一种解决方案:“把语法树节点的拷贝构造、移动构造、拷贝赋值、移动赋值全部删除,这样就不能写 auto tmp = foo() 了!”这种写法有一定的道理,但是 C++17 引入了强制拷贝消除,而 foo() 为纯右值,编译器会在 tmp 的位置上原位构造 BitsetFlip 对象,完全不需要复制构造和移动构造函数,因而这种方案行不通。
所以请不要在 Bitset 相关的表达式中使用 auto 进行类型推断,除非你非常清楚你在做什么。
参考实现
见云剪贴板。
:::info[或者这里]
#ifndef BITSET_H
#define BITSET_H 202605L
#include <stdint.h>
#include <string.h>
#include <immintrin.h>
template <typename E>
struct Expression {
inline const E& operator()() const {
return static_cast<const E&>(*this);
}
};
template <typename E>
struct BitsetFlip : Expression<BitsetFlip<E>> {
const E& mE;
BitsetFlip(const E& e) : mE{e} {}
inline __m256i blockAt(size_t i) const {
return _mm256_xor_si256(mE.blockAt(i), _mm256_set1_epi64x(-1));
}
};
template <typename E>
struct BitsetNot : Expression<BitsetNot<E>> {
const E& mE;
BitsetNot(const E& e) : mE{e} {}
inline __m256i blockAt(size_t i) const {
return _mm256_xor_si256(mE.blockAt(i), _mm256_set1_epi64x(-1));
}
};
template <typename XE, typename YE>
struct BitsetAnd : Expression<BitsetAnd<XE, YE>> {
const XE& mXE;
const YE& mYE;
BitsetAnd(const XE& xE, const YE& yE) : mXE{xE}, mYE{yE} {}
inline __m256i blockAt(size_t i) const {
return _mm256_and_si256(mYE.blockAt(i), mXE.blockAt(i));
}
};
template <typename XE, typename YE>
struct BitsetOr : Expression<BitsetOr<XE, YE>> {
const XE& mXE;
const YE& mYE;
BitsetOr(const XE& xE, const YE& yE) : mXE{xE}, mYE{yE} {}
inline __m256i blockAt(size_t i) const {
return _mm256_or_si256(mYE.blockAt(i), mXE.blockAt(i));
}
};
template <typename XE, typename YE>
struct BitsetXor : Expression<BitsetXor<XE, YE>> {
const XE& mXE;
const YE& mYE;
BitsetXor(const XE& xE, const YE& yE) : mXE{xE}, mYE{yE} {}
inline __m256i blockAt(size_t i) const {
return _mm256_xor_si256(mYE.blockAt(i), mXE.blockAt(i));
}
};
template <typename XE, typename YE>
struct BitsetAnt : Expression<BitsetAnt<XE, YE>> {
const XE& mXE;
const YE& mYE;
BitsetAnt(const XE& xE, const YE& yE) : mXE{xE}, mYE{yE} {}
inline __m256i blockAt(size_t i) const {
return _mm256_andnot_si256(mYE.blockAt(i), mXE.blockAt(i));
}
};
template <typename E>
inline BitsetFlip<E> operator~(const Expression<E>& e) {
return BitsetFlip<E>(e());
}
template <typename E>
inline BitsetNot<E> operator!(const Expression<E>& e) {
return BitsetNot<E>(e());
}
template <typename XE, typename YE>
inline BitsetAnd<XE, YE> operator&(const Expression<XE>& xE, const Expression<YE>& yE) {
return BitsetAnd<XE, YE>(xE(), yE());
}
template <typename XE, typename YE>
inline BitsetOr<XE, YE> operator|(const Expression<XE>& xE, const Expression<YE>& yE) {
return BitsetOr<XE, YE>(xE(), yE());
}
template <typename XE, typename YE>
inline BitsetXor<XE, YE> operator^(const Expression<XE>& xE, const Expression<YE>& yE) {
return BitsetXor<XE, YE>(xE(), yE());
}
template <typename XE, typename YE>
inline BitsetAnt<XE, YE> operator-(const Expression<XE>& xE, const Expression<YE>& yE) {
return BitsetAnt<XE, YE>(xE(), yE());
}
template <size_t kSize>
class Bitset : public Expression<Bitset<kSize>> {
protected:
typedef uint64_t Word;
typedef __m256i Block;
static const size_t kWordSize = sizeof(Word) * CHAR_BIT;
static const size_t kBlockSize = sizeof(Block) / sizeof(Word);
static const size_t kWordCount = (kSize + kWordSize - 1) / kWordSize;
static const size_t kBlockCount = (kWordCount + kBlockSize - 1) / kBlockSize;
alignas(Block) Word mData[kBlockCount * kBlockSize]{};
public:
Bitset() = default;
~Bitset() = default;
inline Word wordAt(size_t position) const {
return mData[position];
}
inline Block blockAt(size_t position) const {
return _mm256_load_si256((const Block*)&mData[position * kBlockSize]);
}
protected:
inline void wordSet(size_t position, Word value) {
mData[position] = value;
}
inline void blockSet(size_t position, Block value) {
_mm256_store_si256((Block*)&mData[position * kBlockSize], value);
}
void trim() {
if (kSize % kWordSize) mData[kWordCount - 1] &= (1ULL << (kSize % kWordSize)) - 1;
if (kBlockCount * kBlockSize > kWordCount) memset(mData + kWordCount, 0, (kBlockCount * kBlockSize - kWordCount) * sizeof(Word));
}
void normalize() {
if (kSize % kWordSize) mData[kWordCount - 1] &= (1ULL << (kSize % kWordSize)) - 1;
if (kBlockCount * kBlockSize > kWordCount) memset(mData + kWordCount, 0, (kBlockCount * kBlockSize - kWordCount) * sizeof(Word));
}
public:
Bitset(const Bitset&) = default;
Bitset(Bitset&&) = default;
template <typename E>
Bitset(const Expression<E>& e) {
for (size_t i = 0; i != kBlockCount; ++i)
blockSet(i, e().blockAt(i));
trim();
}
Bitset& operator=(const Bitset&) = default;
Bitset& operator=(Bitset&&) = default;
template <typename E>
Bitset& operator=(const Expression<E>& e) {
for (size_t i = 0; i != kBlockCount; ++i)
blockSet(i, e().blockAt(i));
trim();
return *this;
}
bool operator[](size_t position) const {
if (position >= kSize) return false;
return mData[position / kWordSize] >> position % kWordSize & 1;
}
bool operator()(size_t position) const {
if (position >= kSize) return false;
return mData[position / kWordSize] >> position % kWordSize & 1;
}
void set(size_t position) {
if (position >= kSize) return;
mData[position / kWordSize] |= 1ULL << position % kWordSize;
}
void unset(size_t position) {
if (position >= kSize) return;
mData[position / kWordSize] &= ~(1ULL << position % kWordSize);
}
void flip(size_t position) {
if (position >= kSize) return;
mData[position / kWordSize] ^= 1ULL << position % kWordSize;
}
size_t size() const {
return kSize;
}
size_t length() const {
return kSize;
}
size_t count() const {
const Block mask = _mm256_set1_epi8(0x0F);
const Block table = _mm256_setr_epi8(
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4
);
Block result = _mm256_setzero_si256();
for (size_t i = 0; i != kBlockCount; ++i) {
Block block = blockAt(i);
Block low = _mm256_shuffle_epi8(table, _mm256_and_si256(block, mask));
Block high = _mm256_shuffle_epi8(table, _mm256_and_si256(_mm256_srli_epi16(block, 4), mask));
result = _mm256_add_epi64(result, _mm256_sad_epu8(_mm256_add_epi8(low, high), _mm256_setzero_si256()));
}
Word parts[kBlockSize];
_mm256_storeu_si256((Block*)parts, result);
return size_t(parts[0] + parts[1] + parts[2] + parts[3]);
}
size_t popcnt() const {
const Block mask = _mm256_set1_epi8(0x0F);
const Block table = _mm256_setr_epi8(
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4
);
Block result = _mm256_setzero_si256();
for (size_t i = 0; i != kBlockCount; ++i) {
Block block = blockAt(i);
Block low = _mm256_shuffle_epi8(table, _mm256_and_si256(block, mask));
Block high = _mm256_shuffle_epi8(table, _mm256_and_si256(_mm256_srli_epi16(block, 4), mask));
result = _mm256_add_epi64(result, _mm256_sad_epu8(_mm256_add_epi8(low, high), _mm256_setzero_si256()));
}
Word parts[kBlockSize];
_mm256_storeu_si256((Block*)parts, result);
return size_t(parts[0] + parts[1] + parts[2] + parts[3]);
}
bool none() const {
const size_t fullBlocks = kSize / (sizeof(Block) * CHAR_BIT);
for (size_t i = 0; i != fullBlocks; ++i) {
Block mask = _mm256_set1_epi8(-1), block = blockAt(i);
if (!_mm256_testz_si256(block, mask)) return false;
}
if (size_t remaining = kSize % (sizeof(Block) * CHAR_BIT)) {
Word maskArray[kBlockSize] = {};
for (size_t i = 0; i != kBlockSize; ++i) {
if (remaining >= kWordSize) {
maskArray[i] = ~0ULL;
remaining -= kWordSize;
} else {
maskArray[i] = (1ULL << remaining) - 1;
break;
}
}
Block mask = _mm256_loadu_si256((const Block*)maskArray), block = blockAt(fullBlocks);
if (!_mm256_testz_si256(block, mask)) return false;
}
return true;
}
bool any() const {
const size_t fullBlocks = kSize / (sizeof(Block) * CHAR_BIT);
for (size_t i = 0; i != fullBlocks; ++i) {
Block mask = _mm256_set1_epi8(-1), block = blockAt(i);
if (!_mm256_testz_si256(block, mask)) return true;
}
if (size_t remaining = kSize % (sizeof(Block) * CHAR_BIT)) {
Word maskArray[kBlockSize] = {};
for (size_t i = 0; i != kBlockSize; ++i) {
if (remaining >= kWordSize) {
maskArray[i] = ~0ULL;
remaining -= kWordSize;
} else {
maskArray[i] = (1ULL << remaining) - 1;
break;
}
}
Block mask = _mm256_loadu_si256((const Block*)maskArray), block = blockAt(fullBlocks);
if (!_mm256_testz_si256(block, mask)) return true;
}
return false;
}
bool all() const {
const size_t fullBlocks = kSize / (sizeof(Block) * CHAR_BIT);
for (size_t i = 0; i != fullBlocks; ++i) {
Block mask = _mm256_set1_epi8(-1), block = blockAt(i);
if (!_mm256_testc_si256(block, mask)) return false;
}
if (size_t remaining = kSize % (sizeof(Block) * CHAR_BIT)) {
Word maskArray[kBlockSize] = {};
for (size_t i = 0; i != kBlockSize; ++i) {
if (remaining >= kWordSize) {
maskArray[i] = ~0ULL;
remaining -= kWordSize;
} else {
maskArray[i] = (1ULL << remaining) - 1;
break;
}
}
Block mask = _mm256_loadu_si256((const Block*)maskArray), block = blockAt(fullBlocks);
if (!_mm256_testc_si256(block, mask)) return false;
}
return true;
}
void set() {
const Block value = _mm256_set1_epi64x(-1);
for (size_t i = 0; i != kBlockCount; ++i)
blockSet(i, value);
trim();
}
void unset() {
const Block value = _mm256_setzero_si256();
for (size_t i = 0; i != kBlockCount; ++i)
blockSet(i, value);
trim();
}
void flip() {
const Block value = _mm256_set1_epi64x(-1);
for (size_t i = 0; i != kBlockCount; ++i)
blockSet(i, _mm256_xor_si256(blockAt(i), value));
trim();
}
template <typename E>
Bitset& operator&=(const Expression<E>& e) {
for (size_t i = 0; i != kBlockCount; ++i)
blockSet(i, _mm256_and_si256(blockAt(i), e().blockAt(i)));
trim();
return *this;
}
template <typename E>
Bitset& operator|=(const Expression<E>& e) {
for (size_t i = 0; i != kBlockCount; ++i)
blockSet(i, _mm256_or_si256(blockAt(i), e().blockAt(i)));
trim();
return *this;
}
template <typename E>
Bitset& operator^=(const Expression<E>& e) {
for (size_t i = 0; i != kBlockCount; ++i)
blockSet(i, _mm256_xor_si256(blockAt(i), e().blockAt(i)));
trim();
return *this;
}
template <typename E>
Bitset& operator-=(const Expression<E>& e) {
for (size_t i = 0; i != kBlockCount; ++i)
blockSet(i, _mm256_andnot_si256(e().blockAt(i), blockAt(i)));
trim();
return *this;
}
friend bool operator==(const Bitset& A, const Bitset& B) {
for (size_t i = 0; i != kBlockCount; ++i) {
Block bxa = _mm256_xor_si256(B.blockAt(i), A.blockAt(i));
if (!_mm256_testz_si256(bxa, bxa)) return false;
}
return true;
}
friend bool operator!=(const Bitset& A, const Bitset& B) {
for (size_t i = 0; i != kBlockCount; ++i) {
Block bxa = _mm256_xor_si256(B.blockAt(i), A.blockAt(i));
if (!_mm256_testz_si256(bxa, bxa)) return true;
}
return false;
}
friend bool operator<=(const Bitset& A, const Bitset& B) {
for (size_t i = 0; i != kBlockCount; ++i)
if (!_mm256_testc_si256(B.blockAt(i), A.blockAt(i)))
return false;
return true;
}
friend bool operator>=(const Bitset& A, const Bitset& B) {
for (size_t i = 0; i != kBlockCount; ++i)
if (!_mm256_testc_si256(A.blockAt(i), B.blockAt(i)))
return false;
return true;
}
friend bool operator<(const Bitset& A, const Bitset& B) {
bool different = false;
for (size_t i = 0; i != kBlockCount; ++i) {
Block a = A.blockAt(i), b = B.blockAt(i);
if (!_mm256_testc_si256(b, a)) return false;
different |= !_mm256_testc_si256(a, b);
}
return different;
}
friend bool operator>(const Bitset& A, const Bitset& B) {
bool different = false;
for (size_t i = 0; i != kBlockCount; ++i) {
Block a = A.blockAt(i), b = B.blockAt(i);
if (!_mm256_testc_si256(a, b)) return false;
different |= !_mm256_testc_si256(b, a);
}
return different;
}
size_t findFirstSet(size_t position) const {
if (position >= kSize) return size_t(-1);
size_t wordIndex = position / kWordSize, bitInWord = position % kWordSize;
size_t blockIndex = wordIndex / kBlockSize, wordInBlock = wordIndex % kBlockSize;
if (Word word = mData[wordIndex] & ~((1ULL << bitInWord) - 1)) {
size_t result = wordIndex * kWordSize + size_t(__builtin_ctzll(word));
return result < kSize ? result : size_t(-1);
}
while (++wordInBlock != kBlockSize) {
size_t current = blockIndex * kBlockSize + wordInBlock;
if (current >= kWordCount) break;
if (Word word = mData[current]) {
size_t result = current * kWordSize + __builtin_ctzll(word);
return result < kSize ? result : size_t(-1);
}
}
while (++blockIndex != kBlockCount) {
Block mask = _mm256_set1_epi64x(-1), block = blockAt(blockIndex);
if (_mm256_testz_si256(block, mask)) continue;
for (size_t i = 0; i != kBlockSize; ++i) {
size_t current = blockIndex * kBlockSize + i;
if (current >= kWordCount) break;
if (Word word = mData[current]) {
size_t result = current * kWordSize + __builtin_ctzll(word);
return result < kSize ? result : size_t(-1);
}
}
}
return size_t(-1);
}
size_t findFirstUnset(size_t position) const {
if (position >= kSize) return size_t(-1);
size_t wordIndex = position / kWordSize, bitInWord = position % kWordSize;
size_t blockIndex = wordIndex / kBlockSize, wordInBlock = wordIndex % kBlockSize;
if (Word word = ~mData[wordIndex] & ~((1ULL << bitInWord) - 1)) {
size_t result = wordIndex * kWordSize + size_t(__builtin_ctzll(word));
return result < kSize ? result : size_t(-1);
}
while (++wordInBlock != kBlockSize) {
size_t current = blockIndex * kBlockSize + wordInBlock;
if (current >= kWordCount) break;
if (Word word = ~mData[current]) {
size_t result = current * kWordSize + __builtin_ctzll(word);
return result < kSize ? result : size_t(-1);
}
}
while (++blockIndex != kBlockCount) {
Block mask = _mm256_set1_epi64x(-1), block = blockAt(blockIndex);
if (_mm256_testc_si256(block, mask)) continue;
for (size_t i = 0; i != kBlockSize; ++i) {
size_t current = blockIndex * kBlockSize + i;
if (current >= kWordCount) break;
if (Word word = ~mData[current]) {
size_t result = current * kWordSize + __builtin_ctzll(word);
return result < kSize ? result : size_t(-1);
}
}
}
return size_t(-1);
}
void set(size_t position, size_t length) {
if (position + length > kSize || !length) return;
size_t wordIndex = position / kWordSize, bitInWord = position % kWordSize;
if (bitInWord) {
size_t headLength = kWordSize - bitInWord < length ? kWordSize - bitInWord : length;
mData[wordIndex++] |= ((1ULL << headLength) - 1) << bitInWord;
length -= headLength;
}
for (const Block value = _mm256_set1_epi64x(-1); length >= sizeof(Block) * CHAR_BIT; ) {
_mm256_storeu_si256((Block*)&mData[wordIndex], value);
length -= sizeof(Block) * CHAR_BIT;
wordIndex += kBlockSize;
}
while (length >= kWordSize) {
mData[wordIndex++] = -1ULL;
length -= kWordSize;
}
mData[wordIndex] |= (1ULL << length) - 1;
}
void unset(size_t position, size_t length) {
if (position + length > kSize || !length) return;
size_t wordIndex = position / kWordSize, bitInWord = position % kWordSize;
if (bitInWord) {
size_t headLength = kWordSize - bitInWord < length ? kWordSize - bitInWord : length;
mData[wordIndex++] &= ~(((1ULL << headLength) - 1) << bitInWord);
length -= headLength;
}
for (const Block value = _mm256_setzero_si256(); length >= sizeof(Block) * CHAR_BIT; ) {
_mm256_storeu_si256((Block*)&mData[wordIndex], value);
length -= sizeof(Block) * CHAR_BIT;
wordIndex += kBlockSize;
}
while (length >= kWordSize) {
mData[wordIndex++] = 0ULL;
length -= kWordSize;
}
mData[wordIndex] &= ~((1ULL << length) - 1);
}
void flip(size_t position, size_t length) {
if (position + length > kSize || !length) return;
size_t wordIndex = position / kWordSize, bitInWord = position % kWordSize;
if (bitInWord) {
size_t headLength = kWordSize - bitInWord < length ? kWordSize - bitInWord : length;
mData[wordIndex++] ^= ((1ULL << headLength) - 1) << bitInWord;
length -= headLength;
}
for (const Block value = _mm256_set1_epi64x(-1); length >= sizeof(Block) * CHAR_BIT; ) {
_mm256_storeu_si256((Block*)&mData[wordIndex], _mm256_xor_si256(_mm256_loadu_si256((const Block*)&mData[wordIndex]), value));
length -= sizeof(Block) * CHAR_BIT;
wordIndex += kBlockSize;
}
while (length >= kWordSize) {
mData[wordIndex++] ^= -1ULL;
length -= kWordSize;
}
mData[wordIndex] ^= (1ULL << length) - 1;
}
Bitset& operator<<=(size_t step) {
if (step >= kSize) return unset(), *this;
size_t wordShift = step / kWordSize, bitShift = step % kWordSize;
if (!bitShift) {
memmove(mData + wordShift, mData, (kWordCount - wordShift) * sizeof(Word));
memset(mData, 0, wordShift * sizeof(Word));
} else {
size_t remaining = kWordCount - wordShift - 1;
Word *destination = mData + kWordCount, *source = destination - wordShift;
__m128i lShift = _mm_cvtsi32_si128(int(bitShift)), rShift = _mm_cvtsi32_si128(int(kWordSize - bitShift));
while (remaining >= kBlockSize) {
destination -= kBlockSize, source -= kBlockSize;
Block low = _mm256_srl_epi64(_mm256_loadu_si256((const Block*)&source[-1]), rShift);
Block high = _mm256_sll_epi64(_mm256_loadu_si256((const Block*)&source[0]), lShift);
_mm256_storeu_si256((Block*)destination, _mm256_or_si256(low, high));
remaining -= kBlockSize;
}
while (remaining) {
--destination, --source;
*destination = (source[0] << bitShift) | (source[-1] >> (kWordSize - bitShift));
--remaining;
}
*--destination = *--source << bitShift;
memset(mData, 0, wordShift * sizeof(Word));
}
trim();
return *this;
}
Bitset& operator>>=(size_t step) {
if (step >= kSize) return unset(), *this;
size_t wordShift = step / kWordSize, bitShift = step % kWordSize;
if (!bitShift) {
memmove(mData, mData + wordShift, (kWordCount - wordShift) * sizeof(Word));
memset(mData + kWordCount - wordShift, 0, wordShift * sizeof(Word));
} else {
size_t remaining = kWordCount - wordShift - 1;
Word *destination = mData, *source = mData + wordShift;
__m128i rShift = _mm_cvtsi32_si128(int(bitShift)), lShift = _mm_cvtsi32_si128(int(kWordSize - bitShift));
while (remaining >= kBlockSize) {
Block low = _mm256_srl_epi64(_mm256_loadu_si256((const Block*)&source[0]), rShift);
Block high = _mm256_sll_epi64(_mm256_loadu_si256((const Block*)&source[1]), lShift);
_mm256_storeu_si256((Block*)destination, _mm256_or_si256(low, high));
destination += kBlockSize, source += kBlockSize;
remaining -= kBlockSize;
}
while (remaining) {
*destination = (source[0] >> bitShift) | (source[1] << (kWordSize - bitShift));
++destination, ++source;
--remaining;
}
*destination = *source >> bitShift;
memset(mData + kWordCount - wordShift, 0, wordShift * sizeof(Word));
}
trim();
return *this;
}
Bitset operator<<(size_t step) const {
Bitset result;
if (step >= kSize) return result;
size_t wordShift = step / kWordSize, bitShift = step % kWordSize;
if (!bitShift) {
memcpy(result.mData + wordShift, mData, (kWordCount - wordShift) * sizeof(Word));
} else {
size_t remaining = kWordCount - wordShift - 1;
Word* destination = result.mData + kWordCount;
const Word* source = mData + kWordCount - wordShift;
__m128i lShift = _mm_cvtsi32_si128(int(bitShift)), rShift = _mm_cvtsi32_si128(int(kWordSize - bitShift));
while (remaining >= kBlockSize) {
destination -= kBlockSize, source -= kBlockSize;
Block low = _mm256_srl_epi64(_mm256_loadu_si256((const Block*)&source[-1]), rShift);
Block high = _mm256_sll_epi64(_mm256_loadu_si256((const Block*)&source[0]), lShift);
_mm256_storeu_si256((Block*)destination, _mm256_or_si256(low, high));
remaining -= kBlockSize;
}
while (remaining) {
--destination, --source;
*destination = (source[0] << bitShift) | (source[-1] >> (kWordSize - bitShift));
--remaining;
}
*--destination = *--source << bitShift;
}
result.trim();
return result;
}
Bitset operator>>(size_t step) const {
Bitset result;
if (step >= kSize) return result;
size_t wordShift = step / kWordSize, bitShift = step % kWordSize;
if (!bitShift) {
memcpy(result.mData, mData + wordShift, (kWordCount - wordShift) * sizeof(Word));
} else {
size_t remaining = kWordCount - wordShift - 1;
Word* destination = result.mData;
const Word* source = mData + wordShift;
__m128i rShift = _mm_cvtsi32_si128(int(bitShift)), lShift = _mm_cvtsi32_si128(int(kWordSize - bitShift));
while (remaining >= kBlockSize) {
Block low = _mm256_srl_epi64(_mm256_loadu_si256((const Block*)&source[0]), rShift);
Block high = _mm256_sll_epi64(_mm256_loadu_si256((const Block*)&source[1]), lShift);
_mm256_storeu_si256((Block*)destination, _mm256_or_si256(low, high));
destination += kBlockSize, source += kBlockSize;
remaining -= kBlockSize;
}
while (remaining) {
*destination = (source[0] >> bitShift) | (source[1] << (kWordSize - bitShift));
++destination, ++source;
--remaining;
}
*destination = *source >> bitShift;
}
result.trim();
return result;
}
};
#endif
:::
性能测试
为了展示我们的向量化 bitset 与一般的标量 bitset 的性能差距,我们进行了一组性能测试。细节如下:
- 处理器:
Intel(R) Core(TM) i7-10700 CPU @ 2.90GHz - 编译器:
g++.exe (x86_64-win32-seh-rev2, Built by MinGW-Builds project) 14.2.0 - 编译参数:
g++ -std=c++20 -march=native -Ofast - 测试参数:
n=2^{23} 每项操作执行10^3 次取总用时,单位为毫秒。
| 操作类型 | 标量用时 | 向量用时 | 提升倍率 | 其它说明 |
|---|---|---|---|---|
| 逻辑运算 | 按位与以及一次赋值 | |||
| 关系查询 | 子集关系查询,构造数据使得前 |
|||
| 区间修改 | 区间赋值为 |
|||
| 状态查询 | 是否全 |
|||
| 位置查找 | 查找下一个 |
|||
| 左右移位 | 左移以及一次赋值 | |||
| 个数统计 | 随机数据 | |||
| 嵌套运算 | 嵌套四次按位与以及一次赋值 |
可以看到,即使在 -Ofast -march=native 的加持下,编译器对标量 for 的自动向量化依旧比不上我们手写的 AVX2 Intrinsics。逻辑运算的提升达到了
而对于位置查找与左右移位这两个非对齐操作,向量化也带来了明显的提升,分别为 -std=c++20 -march=native 编译参数下 std::popcount 会被识别并编译为极快的 POPCNT 硬件指令,尽管如此我们的并行查表法依然凭借每次处理
最为恐怖的是嵌套运算,足足达到了