如何 n^2 过百万:题解 P5066 【[Ynoi2014]人人本着正义之名】

· · 题解

我用 暴力 来过这道题!

首先,介绍一下 AVX512 Intrinsics:

AVX512 Intrinsics 是一个方法来在特别小的常数 (0.52)进行一个 512 位数逻辑或者数学操作。这是这个方法的关键:我们是 O(nm),但是常数是 O(1/512),所以不仅可以过也可以卡到最慢点只有 3s。

来应用 AVX512 Intrinsics,在代码里需要加入

#pragma GCC target("avx512f,avx512bw")
#include <immintrin.h>

就可以了。注意这些代码可能在本地无法运行,应为大部分的电脑都不支持 AVX512 Intrinsics,但是洛谷评测机支持就够了。

第一步是把给定的序列进行分块(Ynoi 题,还能是啥?)- 但不是普通的线性分块,是转置成 512 行的分块。假设我们分成 W 块,那么块表存储的值就会像这样:

 0    |  1     |  2     | ... | W-1
 W    | W+1    | W+2    | ... | 2W-1
 2W   | 2W+1   | 2W+2   | ... | 3W-1
 ...
 511W | 511W+1 | 511W+1 | ... | 512W-1

这样的最大好处就是在原序列的往左/右移动一位就变成往左/右移动一整个块,好处理很多。注意在每次操作之前先需要更新最左块左面的块和最右块右面的块。

那么对与一个操作 [l,r],影响的位置会看的大概像这样:(1 = 影响,0 = 没影响)

00000000000000...
00000011111111...
11111111111111...
11111111100000...
00000000000000...0

那么可以直接把每个询问拆成三个块来执行操作,分成以下的一种:

  1. [0, l\mod W],[l\mod W,r\mod W],[r\mod W,W]
  2. [0, r\mod W],[r\mod W,l\mod W],[l\mod W,W]

现在,需要对与一个操作块找到对应影响位置位掩码。怎么找呢?

假设操作块是 [l,r],本操作是 [L,R]

  1. 如果一个在位掩码位置 i 是 1,那么满足 Wi+L\in[l,r)
  2. 所以 l\le Wi+L\ \wedge\ Wi+L<r
  3. 所以 l-L\le Wi\ \wedge\ Wi<r-L
  4. 所以 \frac{l-L}{W}\le i\ \wedge\ i\le\frac{r-L-1}{W}
  5. 所以 \lceil\frac{l-L}{W}\rceil\le i\ \wedge\ i\le\lfloor\frac{r-L-1}{W}\rfloor

如果左端点小于 0,那直接不用做这个操作块,否则影响位掩码的所有左端点到右端点位置都是 1,预处理前缀 1 的位掩码足够来 O(1) 求。

现在做 1~6 操作就很简单了,直接做个 for 循环来更新位置上的值。

如何做 7 操作呢?7 操作等价于统计有多少个 1 位置在位掩码里,但是直接 naive 统计太慢了。看一下 这里的速度,发现最快的一种方法来统计是叫 ”avx512-harley-seal“,所以直接复制过来 这些代码 和 这些代码

然后就好了。

代码:

#pragma GCC optimize("O4")
#pragma GCC target("avx512f,avx512bw")
// writer: w33z8kqrqk8zzzx33
#include <bits/stdc++.h>
#include <immintrin.h>
using namespace std;

// begin fast read template by CYJian (source: https://www.luogu.com.cn/paste/i11c3ppx)

namespace io {
    const int __SIZE = (1 << 21) + 1;
    char ibuf[__SIZE], *iS, *iT, obuf[__SIZE], *oS = obuf, *oT = oS + __SIZE - 1, __c, qu[55]; int __f, qr, _eof;
    #define Gc() (iS == iT ? (iT = (iS = ibuf) + fread (ibuf, 1, __SIZE, stdin), (iS == iT ? EOF : *iS ++)) : *iS ++)
    inline void flush () { fwrite (obuf, 1, oS - obuf, stdout), oS = obuf; }
    inline void gc (char &x) { x = Gc(); }
    inline void pc (char x) { *oS ++ = x; if (oS == oT) flush (); }
    inline void pstr (const char *s) { int __len = strlen(s); for (__f = 0; __f < __len; ++__f) pc (s[__f]); }
    inline void gstr (char *s) { for(__c = Gc(); __c < 32 || __c > 126 || __c == ' ';)  __c = Gc();
        for(; __c > 31 && __c < 127 && __c != ' '; ++s, __c = Gc()) *s = __c; *s = 0; }
    template <class I> inline bool gi (I &x) { _eof = 0;
        for (__f = 1, __c = Gc(); (__c < '0' || __c > '9') && !_eof; __c = Gc()) { if (__c == '-') __f = -1; _eof |= __c == EOF; }
        for (x = 0; __c <= '9' && __c >= '0' && !_eof; __c = Gc()) x = x * 10 + (__c & 15), _eof |= __c == EOF; x *= __f; return !_eof; }
    template <class I> inline void print (I x) { if (!x) pc ('0'); if (x < 0) pc ('-'), x = -x;
        while (x) qu[++ qr] = x % 10 + '0',  x /= 10; while (qr) pc (qu[qr --]); }
    struct Flusher_ {~Flusher_(){flush();}}io_flusher_;
} using io::pc; using io::gc; using io::pstr; using io::gstr; using io::gi; using io::print;

// end fast read template by CYJian

#define iter(i, a, b) for(int i=(a); i<(b); i++)
#define reti(i, a, b) for(int i=(b)-1; i>=(a); i--)
#define rep(i, a) iter(i, 0, a)
#define rep1(i, a) iter(i, 1, (a)+1)
#define all(a) a.begin(), a.end()
#define fi first
#define se second
#define pb push_back

using ll=long long;

using M=__m512i;
union U { M m; uint64_t words[8]; };

inline void setbit(U& u, const unsigned& bit) { u.words[bit>>6] |= 1ull<<(bit&63);}

void RR(U& u) {
    u.words[0] = (u.words[0]>>1) | (u.words[1]<<63);
    u.words[1] = (u.words[1]>>1) | (u.words[2]<<63);
    u.words[2] = (u.words[2]>>1) | (u.words[3]<<63);
    u.words[3] = (u.words[3]>>1) | (u.words[4]<<63);
    u.words[4] = (u.words[4]>>1) | (u.words[5]<<63);
    u.words[5] = (u.words[5]>>1) | (u.words[6]<<63);
    u.words[6] = (u.words[6]>>1) | (u.words[7]<<63);
    u.words[7] = u.words[7]>>1;
}

void LL(U& u) {
    u.words[7] = (u.words[7]<<1) | (u.words[6]>>63);
    u.words[6] = (u.words[6]<<1) | (u.words[5]>>63);
    u.words[5] = (u.words[5]<<1) | (u.words[4]>>63);
    u.words[4] = (u.words[4]<<1) | (u.words[3]>>63);
    u.words[3] = (u.words[3]<<1) | (u.words[2]>>63);
    u.words[2] = (u.words[2]<<1) | (u.words[1]>>63);
    u.words[1] = (u.words[1]<<1) | (u.words[0]>>63);
    u.words[0] = u.words[0]<<1;
}

M zero, one;

int opt, l, r, sum;
U* blocks;
U prefixes[513];
int nblocks;

namespace SimdSum {
struct sse_vector final {
    union {
        __m128i  v;
        uint8_t  u8[16];
        uint16_t u16[8];
        uint32_t u32[4];
        uint64_t u64[2];
    };

    sse_vector() = delete;
    sse_vector(sse_vector&) = delete;

    explicit sse_vector(const __m128i& vec): v(vec) {}
};

__m128i operator&(sse_vector a, sse_vector b) {

    return _mm_and_si128(a.v, b.v);
}

__m128i operator|(sse_vector a, sse_vector b) {

    return _mm_or_si128(a.v, b.v);
}

__m128i operator^(sse_vector a, sse_vector b) {

    return _mm_xor_si128(a.v, b.v);
}

struct shift16 final {
    const unsigned bits;

    shift16() = delete;
    explicit shift16(unsigned bits) : bits(bits) {};
};

__m128i operator>>(const __m128i a, const shift16 amount) {

    return _mm_srli_epi16(a, amount.bits);
}

uint64_t lower_qword(const __m128i v) {

    return _mm_cvtsi128_si64(v);
}

uint64_t higher_qword(const __m128i v) {

    return lower_qword(_mm_srli_si128(v, 8));
}

uint64_t simd_sum_epu64(const __m128i v) {

    return lower_qword(v) + higher_qword(v);
}

uint64_t simd_sum_epu64(const __m256i v) {

    return static_cast<uint64_t>(_mm256_extract_epi64(v, 0))
         + static_cast<uint64_t>(_mm256_extract_epi64(v, 1))
         + static_cast<uint64_t>(_mm256_extract_epi64(v, 2))
         + static_cast<uint64_t>(_mm256_extract_epi64(v, 3));
}

uint64_t simd_sum_epu64(const __m512i v) {

    const __m256i lo = _mm512_extracti64x4_epi64(v, 0);
    const __m256i hi = _mm512_extracti64x4_epi64(v, 1);

    return simd_sum_epu64(lo) + simd_sum_epu64(hi);
}
};

namespace AVX512_harley_seal {

__m512i popcount(const __m512i v)
{
  const __m512i m1 = _mm512_set1_epi8(0x55);
  const __m512i m2 = _mm512_set1_epi8(0x33);
  const __m512i m4 = _mm512_set1_epi8(0x0F);

  const __m512i t1 = _mm512_sub_epi8(v,       (_mm512_srli_epi16(v,  1) & m1));
  const __m512i t2 = _mm512_add_epi8(t1 & m2, (_mm512_srli_epi16(t1, 2) & m2));
  const __m512i t3 = _mm512_add_epi8(t2, _mm512_srli_epi16(t2, 4)) & m4;
  return _mm512_sad_epu8(t3, _mm512_setzero_si512());
}

void CSA(__m512i& h, __m512i& l, __m512i a, __m512i b, __m512i c)
{
    /*
        c b a | l h
        ------+----
        0 0 0 | 0 0
        0 0 1 | 1 0
        0 1 0 | 1 0
        0 1 1 | 0 1
        1 0 0 | 1 0
        1 0 1 | 0 1
        1 1 0 | 0 1
        1 1 1 | 1 1
        l - digit
        h - carry
    */

  l = _mm512_ternarylogic_epi32(c, b, a, 0x96);
  h = _mm512_ternarylogic_epi32(c, b, a, 0xe8);
}

uint64_t popcnt(const U* data, const __m512i& mask, const uint64_t size)
{
  __m512i total     = _mm512_setzero_si512();
  __m512i ones      = _mm512_setzero_si512();
  __m512i twos      = _mm512_setzero_si512();
  __m512i fours     = _mm512_setzero_si512();
  __m512i eights    = _mm512_setzero_si512();
  __m512i sixteens  = _mm512_setzero_si512();
  __m512i twosA, twosB, foursA, foursB, eightsA, eightsB;

  const uint64_t limit = size - size % 16;
  uint64_t i = 0;

  for(; i < limit; i += 16)
  {
    CSA(twosA, ones, ones, data[i+0].m & mask, data[i+1].m & mask);
    CSA(twosB, ones, ones, data[i+2].m & mask, data[i+3].m & mask);
    CSA(foursA, twos, twos, twosA, twosB);
    CSA(twosA, ones, ones, data[i+4].m & mask, data[i+5].m & mask);
    CSA(twosB, ones, ones, data[i+6].m & mask, data[i+7].m & mask);
    CSA(foursB, twos, twos, twosA, twosB);
    CSA(eightsA,fours, fours, foursA, foursB);
    CSA(twosA, ones, ones, data[i+8].m & mask, data[i+9].m & mask);
    CSA(twosB, ones, ones, data[i+10].m & mask, data[i+11].m & mask);
    CSA(foursA, twos, twos, twosA, twosB);
    CSA(twosA, ones, ones, data[i+12].m & mask, data[i+13].m & mask);
    CSA(twosB, ones, ones, data[i+14].m & mask, data[i+15].m & mask);
    CSA(foursB, twos, twos, twosA, twosB);
    CSA(eightsB, fours, fours, foursA, foursB);
    CSA(sixteens, eights, eights, eightsA, eightsB);

    total = _mm512_add_epi64(total, popcount(sixteens));
  }

  total = _mm512_slli_epi64(total, 4);     // * 16
  total = _mm512_add_epi64(total, _mm512_slli_epi64(popcount(eights), 3)); // += 8 * ...
  total = _mm512_add_epi64(total, _mm512_slli_epi64(popcount(fours),  2)); // += 4 * ...
  total = _mm512_add_epi64(total, _mm512_slli_epi64(popcount(twos),   1)); // += 2 * ...
  total = _mm512_add_epi64(total, popcount(ones));

  for(; i < size; i++)
    total = _mm512_add_epi64(total, popcount(data[i].m & mask));

  return SimdSum::simd_sum_epu64(total);
}

} // AVX512_harley_seal

void process(int L, int R) {
    M u;
    int low = (l-L-1+(nblocks<<4))/nblocks-15;
    int high = (r-L-1+(nblocks<<4))/nblocks-16;
    if(opt == 1 || opt == 5 || opt == 6) u = (high < 0) ? one : _mm512_andnot_si512(_mm512_andnot_si512(prefixes[low].m, prefixes[high+1].m), one);
    else u = (high < 0) ? zero : _mm512_andnot_si512(prefixes[low].m, prefixes[high+1].m);
    L++; R++;
    if(opt == 1) iter(i, L, R) blocks[i].m &= u;
    else if(opt == 2) iter(i, L, R) blocks[i].m |= u;
    else if(opt == 3) iter(i, L, R) blocks[i].m |= blocks[i+1].m & u;
    else if(opt == 4) reti(i, L, R) blocks[i].m |= blocks[i-1].m & u;
    else if(opt == 5) iter(i, L, R) blocks[i].m &= blocks[i+1].m | u;
    else if(opt == 6) reti(i, L, R) blocks[i].m &= blocks[i-1].m | u;
    else sum += AVX512_harley_seal::popcnt(blocks+L, u, R-L);
}

signed main() {
    ios_base::sync_with_stdio(false); cin.tie(0);
    int N, Q; gi(N), gi(Q);
    nblocks = max((N+511)/512, 1);
    zero = _mm512_setzero_si512();
    prefixes[0].m = prefixes[1].m = zero;
    prefixes[1].words[0] |= 1;
    iter(i, 2, 513) {
        prefixes[i].m = prefixes[i-1].m;
        prefixes[i].words[(i-1)>>6] |= 1ull<<((i-1)&63);
    }
    one = _mm512_set1_epi32(-1);
    blocks = (U*)_mm_malloc((nblocks+2)*(sizeof(U)), sizeof(U));
    rep(i, N) {
        int k; gi(k);
        int bit = i / nblocks, block = i % nblocks;
        if(k) setbit(blocks[block+1], bit);
    }
    while(Q--) {
        gi(opt), gi(l), gi(r);
        l -= (opt != 4 && opt != 6);
        r -= (opt == 3 || opt == 5);
        blocks[0] = blocks[nblocks]; LL(blocks[0]);
        blocks[nblocks+1] = blocks[1]; RR(blocks[nblocks+1]);
        int v1 = 0, v2 = l%nblocks, v3 = r%nblocks, v4 = nblocks;
        if(v2 > v3) swap(v2, v3);
        sum = 0;
        if(opt == 4 || opt == 6) process(v3, v4), process(v2, v3), process(v1, v2);
        else process(v1, v2), process(v2, v3), process(v3, v4);
        if(opt == 7) print(sum), pc('\n');
    }
}