int SwissTable::early_filter_imp_avx2_x32()

in cpp/src/arrow/compute/key_map_internal_avx2.cc [207:366]


// AVX2 early filter: for each input hash, locates its block, compares the hash's stamp
// against the status bytes of that block, and reports a match flag plus a local slot.
//
// This variant keeps the status bytes of every block in AVX2 registers, so it only
// supports tables with at most 16 blocks (log_blocks_ <= 4, checked by ARROW_DCHECK).
//
// For a hash h, the block id is its top log_blocks_ bits and the stamp is the 7 bits
// immediately below them.
//
// Hashes are processed in whole batches of 32; the trailing num_hashes % 32 hashes are
// not touched (NOTE(review): presumably handled by a scalar fallback in the caller --
// confirm at the call site).
//
// Outputs for every processed hash i:
// - bit i of out_match_bitvector (LSB-first within each byte) is set when some slot of
//   the block holds a byte equal to the stamp, or when the block's last slot is not the
//   empty marker 0x80 (the block is then treated as full and always reported as a match);
// - out_local_slots[i] is the lowest slot index whose byte equals the stamp; failing
//   that, the lowest slot index holding the empty marker 0x80; failing that, 7 (the
//   index of the last slot).
//
// Returns the number of hashes processed: num_hashes rounded down to a multiple of 32.
int SwissTable::early_filter_imp_avx2_x32(const int num_hashes, const uint32_t* hashes,
                                          uint8_t* out_match_bitvector,
                                          uint8_t* out_local_slots) const {
  constexpr int unroll = 32;

  // There is a limit on the number of input blocks,
  // because we want to store all their data in a set of AVX2 registers.
  ARROW_DCHECK(log_blocks_ <= 4);

  // Remember that block bytes and group id bytes are in opposite orders in memory of hash
  // table. We put them in the same order.
  __m256i vblock_byte0, vblock_byte1, vblock_byte2, vblock_byte3, vblock_byte4,
      vblock_byte5, vblock_byte6, vblock_byte7;
  // What we output if there is no match in the block
  __m256i vslot_empty_or_end;

  constexpr uint32_t k4ByteSequence_0_4_8_12 = 0x0c080400;
  constexpr uint32_t k4ByteSequence_1_5_9_13 = 0x0d090501;
  constexpr uint32_t k4ByteSequence_2_6_10_14 = 0x0e0a0602;
  constexpr uint32_t k4ByteSequence_3_7_11_15 = 0x0f0b0703;
  constexpr uint64_t kByteSequence7DownTo0 = 0x0001020304050607ULL;
  constexpr uint64_t kByteSequence15DownTo8 = 0x08090A0B0C0D0E0FULL;

  // Bit unpack group ids into 1B.
  // Assemble the sequence of block bytes.
  //
  // All 16 words are loaded into SIMD registers below, even when the table has fewer
  // blocks. Zero-initialize the array so those unused lanes are never indeterminate
  // values. Block ids are masked to the existing blocks, so the unused lanes are never
  // selected and results do not depend on them.
  uint64_t block_bytes[16] = {};
  const int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
  const int num_block_bytes = num_block_bytes_from_num_groupid_bits(num_groupid_bits);
  for (int i = 0; i < (1 << log_blocks_); ++i) {
    uint64_t in_blockbytes =
        *reinterpret_cast<const uint64_t*>(block_data(i, num_block_bytes));
    block_bytes[i] = in_blockbytes;
  }

  // Split a sequence of 64-bit words into SIMD vectors holding individual bytes
  __m256i vblock_words0 =
      _mm256_loadu_si256(reinterpret_cast<const __m256i*>(block_bytes) + 0);
  __m256i vblock_words1 =
      _mm256_loadu_si256(reinterpret_cast<const __m256i*>(block_bytes) + 1);
  __m256i vblock_words2 =
      _mm256_loadu_si256(reinterpret_cast<const __m256i*>(block_bytes) + 2);
  __m256i vblock_words3 =
      _mm256_loadu_si256(reinterpret_cast<const __m256i*>(block_bytes) + 3);
  // Reverse the bytes in blocks
  __m256i vshuffle_const =
      _mm256_setr_epi64x(kByteSequence7DownTo0, kByteSequence15DownTo8,
                         kByteSequence7DownTo0, kByteSequence15DownTo8);
  vblock_words0 = _mm256_shuffle_epi8(vblock_words0, vshuffle_const);
  vblock_words1 = _mm256_shuffle_epi8(vblock_words1, vshuffle_const);
  vblock_words2 = _mm256_shuffle_epi8(vblock_words2, vshuffle_const);
  vblock_words3 = _mm256_shuffle_epi8(vblock_words3, vshuffle_const);
  split_bytes_avx2(vblock_words0, vblock_words1, vblock_words2, vblock_words3,
                   vblock_byte0, vblock_byte1, vblock_byte2, vblock_byte3, vblock_byte4,
                   vblock_byte5, vblock_byte6, vblock_byte7);

  // Calculate the slot to output when there is no match in a block.
  // It will be the index of the first empty slot (byte 0x80), or 7 (the index of the
  // last slot in the block) if there are no empty slots.
  vslot_empty_or_end = _mm256_set1_epi8(7);
  {
    __m256i vis_empty;
    // Visiting slots from last to first means the lowest empty slot index wins.
#define CMP(VBLOCKBYTE, BYTENUM)                                                         \
  vis_empty =                                                                            \
      _mm256_cmpeq_epi8(VBLOCKBYTE, _mm256_set1_epi8(static_cast<unsigned char>(0x80))); \
  vslot_empty_or_end =                                                                   \
      _mm256_blendv_epi8(vslot_empty_or_end, _mm256_set1_epi8(BYTENUM), vis_empty);
    CMP(vblock_byte7, 7);
    CMP(vblock_byte6, 6);
    CMP(vblock_byte5, 5);
    CMP(vblock_byte4, 4);
    CMP(vblock_byte3, 3);
    CMP(vblock_byte2, 2);
    CMP(vblock_byte1, 1);
    CMP(vblock_byte0, 0);
#undef CMP
  }
  // 0xff for blocks whose last slot is occupied (treated as full), 0x00 otherwise.
  __m256i vblock_is_full = _mm256_andnot_si256(
      _mm256_cmpeq_epi8(vblock_byte7, _mm256_set1_epi8(static_cast<unsigned char>(0x80))),
      _mm256_set1_epi8(static_cast<unsigned char>(0xff)));

  const int block_id_mask = (1 << log_blocks_) - 1;

  // Loop-invariant constants that restore the natural hash order of the per-hash bytes
  // computed in the loop below.
  //
  // vtranspose_bytes transposes the 4x4 byte matrix inside each 128-bit lane:
  // [0, 8, 16, 24, 1, 9, 17, 25, ...] -> [0, 1, 2, 3, 8, 9, 10, 11, ... | 4, 5, 6, 7, ...].
  const __m256i vtranspose_bytes =
      _mm256_setr_epi32(k4ByteSequence_0_4_8_12, k4ByteSequence_1_5_9_13,
                        k4ByteSequence_2_6_10_14, k4ByteSequence_3_7_11_15,
                        k4ByteSequence_0_4_8_12, k4ByteSequence_1_5_9_13,
                        k4ByteSequence_2_6_10_14, k4ByteSequence_3_7_11_15);
  // vinterleave_dwords then interleaves 32-bit words of the two lanes to get [0..31].
  const __m256i vinterleave_dwords = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);

  for (int i = 0; i < num_hashes / unroll; ++i) {
    __m256i vhash0 =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(hashes) + 4 * i + 0);
    __m256i vhash1 =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(hashes) + 4 * i + 1);
    __m256i vhash2 =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(hashes) + 4 * i + 2);
    __m256i vhash3 =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(hashes) + 4 * i + 3);

    // Pack the upper 16 bits of two hashes into each 32-bit lane: hash j in the low half
    // and hash j + 16 in the high half. Only those upper bits feed block id and stamp.
    // We will get input in byte lanes in the order: [0, 8, 16, 24, 1, 9, 17, 25, 2, 10,
    // 18, 26, ...]
    vhash0 = _mm256_or_si256(_mm256_srli_epi32(vhash0, 16),
                             _mm256_and_si256(vhash2, _mm256_set1_epi32(0xffff0000)));
    vhash1 = _mm256_or_si256(_mm256_srli_epi32(vhash1, 16),
                             _mm256_and_si256(vhash3, _mm256_set1_epi32(0xffff0000)));
    __m256i vstamp_A = _mm256_and_si256(_mm256_srli_epi32(vhash0, 16 - log_blocks_ - 7),
                                        _mm256_set1_epi16(0x7f));
    __m256i vstamp_B = _mm256_and_si256(_mm256_srli_epi32(vhash1, 16 - log_blocks_ - 7),
                                        _mm256_set1_epi16(0x7f));
    __m256i vstamp = _mm256_or_si256(vstamp_A, _mm256_slli_epi16(vstamp_B, 8));
    __m256i vblock_id_A = _mm256_and_si256(_mm256_srli_epi32(vhash0, 16 - log_blocks_),
                                           _mm256_set1_epi16(block_id_mask));
    __m256i vblock_id_B = _mm256_and_si256(_mm256_srli_epi32(vhash1, 16 - log_blocks_),
                                           _mm256_set1_epi16(block_id_mask));
    __m256i vblock_id = _mm256_or_si256(vblock_id_A, _mm256_slli_epi16(vblock_id_B, 8));

    // Visit all block bytes in reverse order (overwriting data on multiple matches)
    //
    // Always set match found to true for full blocks.
    //
    __m256i vmatch_found = _mm256_shuffle_epi8(vblock_is_full, vblock_id);
    __m256i vslot_id = _mm256_shuffle_epi8(vslot_empty_or_end, vblock_id);
#define CMP(VBLOCK_BYTE, BYTENUM)                                               \
  {                                                                             \
    __m256i vcmp =                                                              \
        _mm256_cmpeq_epi8(_mm256_shuffle_epi8(VBLOCK_BYTE, vblock_id), vstamp); \
    vmatch_found = _mm256_or_si256(vmatch_found, vcmp);                         \
    vslot_id = _mm256_blendv_epi8(vslot_id, _mm256_set1_epi8(BYTENUM), vcmp);   \
  }
    CMP(vblock_byte7, 7);
    CMP(vblock_byte6, 6);
    CMP(vblock_byte5, 5);
    CMP(vblock_byte4, 4);
    CMP(vblock_byte3, 3);
    CMP(vblock_byte2, 2);
    CMP(vblock_byte1, 1);
    CMP(vblock_byte0, 0);
#undef CMP

    // So far the output is in the order: [0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, ...]
    vmatch_found = _mm256_shuffle_epi8(vmatch_found, vtranspose_bytes);
    // Now it is: [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, | 4, 5, 6, 7,
    // 12, 13, 14, 15, ...]
    vmatch_found = _mm256_permutevar8x32_epi32(vmatch_found, vinterleave_dwords);

    // Repeat the same permutation for slot ids
    vslot_id = _mm256_shuffle_epi8(vslot_id, vtranspose_bytes);
    vslot_id = _mm256_permutevar8x32_epi32(vslot_id, vinterleave_dwords);
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_local_slots) + i, vslot_id);

    // Store the 32 match bits little-endian, one byte at a time. This is the same byte
    // image as a 32-bit store on x86. It requires no particular alignment of
    // out_match_bitvector and avoids writing through a type-punned uint32_t pointer.
    // Optimizing compilers typically merge these into a single 32-bit store.
    const uint32_t match_bits =
        static_cast<uint32_t>(_mm256_movemask_epi8(vmatch_found));
    uint8_t* out_bytes = out_match_bitvector + 4 * i;
    out_bytes[0] = static_cast<uint8_t>(match_bits);
    out_bytes[1] = static_cast<uint8_t>(match_bits >> 8);
    out_bytes[2] = static_cast<uint8_t>(match_bits >> 16);
    out_bytes[3] = static_cast<uint8_t>(match_bits >> 24);
  }

  return num_hashes - (num_hashes % unroll);
}