void SparseDenseInt8MMAvx2()

in src/FbgemmSparseDenseInt8Avx2.cc [53:234]


void SparseDenseInt8MMAvx2(
    int N,
    const std::unique_ptr<BCSRMatrix<>>& bcsr,
    const uint8_t* B,
    int ldb,
    int32_t* C_i32,
    uint8_t* C_u8,
    int ldc,
    trRequantizationParams_t& rParams,
    bool accum,
    int /*thread_id*/,
    int /*num_threads*/) {
  // Calculates accum ? C += A * B : C = A * B
  constexpr int VLEN_INT8 = 32;
  constexpr int VLEN_INT32 = 8;
  constexpr int rowBlockSize = BCSRMatrix<>::RB;
  (void)rowBlockSize; // Suppress unused variable warning
  constexpr int colBlockSize = BCSRMatrix<>::CB;
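  // This kernel assumes 1x4 BCSR blocks: each stored block is one row of A
  // and four consecutive columns, i.e. four int8 weights packed together
  // (see the asserts on rowBlockSize/colBlockSize below).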

  constexpr int colTileSize = BCSRMatrix<>::COLTILE;
  int K = bcsr->C;
  int M = bcsr->R;
  int kTiles = (K + colTileSize - 1) / colTileSize;
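  // The K (reduction) dimension is split into COLTILE-wide column tiles;
  // each tile indexes its own row-pointer segment inside bcsr->rowBPtr.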

  for (int i = 0; i < M; ++i) {
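    // When not accumulating, zero out row i of the int32 accumulator first
    // (full 8-lane stores plus one masked store for the tail of the row).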
    if (!accum) {
      int j = 0;
      __m256i c_v = _mm256_set1_epi32(0);
      for (; j < N / VLEN_INT32 * VLEN_INT32; j += VLEN_INT32) {
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(C_i32 + i * ldc + j), c_v);
      }
      // Handle remainder
      int rem = N - j;
      if (rem > 0) {
        __m256i mask_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
            &avx2_ps_or_epi32_combined_mask[VLEN_INT32 - rem]));
        _mm256_maskstore_epi32(
            reinterpret_cast<int32_t*>(C_i32 + i * ldc + j), mask_v, c_v);
      }
    }
    for (int kt = 0; kt < kTiles; ++kt) {
      int* row_ptr = bcsr->rowBPtr.data() + kt * M;
      int* col_idx = bcsr->colBIdx.data();
      int8_t* values = bcsr->values.data();
      int curKSize = std::min(K - kt * colTileSize, colTileSize);

      int r = row_ptr[i];
      // int r_end_aligned = row_ptr[i] + (row_ptr[i + 1] - row_ptr[i]) / 4 * 4;
      // process one 1x4 nonzero block per iteration (loop not unrolled)
      for (; r < row_ptr[i + 1]; ++r) {
        // the code below is only correct for 1x4 BCSR blocks
        assert(rowBlockSize == 1 && "row block size should be 1");
        assert(colBlockSize == 4 && "column block size should be 4");
        int acbr_block = col_idx[r];
        int32_t v = reinterpret_cast<const int32_t*>(values)[r];
        __m256i a_v = _mm256_set1_epi32(v);
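        // v packs the block's 4 consecutive int8 weights; a_v broadcasts them
        // to every 32-bit lane so each lane sees the same 4 weights.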
        int j = 0;
        for (; j < N / VLEN_INT8 * VLEN_INT8; j += VLEN_INT8) {
          __m256i br_v[4] = {};

          for (int idx = 0;
               idx < std::min(4, curKSize - acbr_block * colBlockSize);
               ++idx) {
            br_v[idx] = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                B + (acbr_block * colBlockSize + idx + kt * colTileSize) * ldb +
                j));
          }

          // interleave these 4 rows
          interleave_4rows(br_v);
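          // After interleaving, each 32-bit lane of br_v[idx] holds the 4
          // B bytes (one from each loaded row) belonging to a single output
          // column, matching the weight layout broadcast in a_v.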

          __m256i one_16bit_v = _mm256_set1_epi16(1);
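          // _mm256_maddubs_epi16 multiplies unsigned B bytes by signed A
          // weights and adds adjacent byte products into int16 lanes;
          // _mm256_madd_epi16 with ones then adds adjacent int16 lanes into
          // int32, completing the 4-term dot product for each output column.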
          __m256i c_v[4];
          for (int idx = 0; idx < 4; ++idx) {
            c_v[idx] = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                C_i32 + i * ldc + j + idx * VLEN_INT32));
            __m256i c_i16_v = _mm256_maddubs_epi16(br_v[idx], a_v);
            __m256i c_i32_v = _mm256_madd_epi16(one_16bit_v, c_i16_v);
            c_v[idx] = _mm256_add_epi32(c_v[idx], c_i32_v);
            _mm256_storeu_si256(
                reinterpret_cast<__m256i*>(
                    C_i32 + i * ldc + j + idx * VLEN_INT32),
                c_v[idx]);
          }
        }
        // Handle remainder j loop
        int rem = N - j;
        if (rem > 0) {
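          // Tail columns: copy the remaining B bytes into a zero-padded
          // buffer so the same full-width loads and interleave can be used.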
          __m256i br_v[4] = {};
          for (int idx = 0;
               idx < std::min(4, curKSize - acbr_block * colBlockSize);
               ++idx) {
            uint8_t tmpDest[VLEN_INT8] = {};
            std::memcpy(
                tmpDest,
                B + (acbr_block * colBlockSize + idx + kt * colTileSize) * ldb +
                    j,
                rem);
            br_v[idx] =
                _mm256_loadu_si256(reinterpret_cast<const __m256i*>(tmpDest));
          }
          // interleave these 4 rows
          interleave_4rows(br_v);

          __m256i c_v[4] = {};
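          // Load the existing C tail: full 8-lane loads where possible, then
          // a masked load for the final partial group of columns.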
          int idx1 = 0;
          for (; idx1 < rem / VLEN_INT32; ++idx1) {
            c_v[idx1] = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                C_i32 + i * ldc + j + idx1 * VLEN_INT32));
          }
          int rem_int32 = rem - idx1 * VLEN_INT32;
          __m256i mask_int32_v;
          if (rem_int32 > 0) {
            mask_int32_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                &avx2_ps_or_epi32_combined_mask[VLEN_INT32 - rem_int32]));
            c_v[idx1] = _mm256_maskload_epi32(
                reinterpret_cast<const int*>(
                    C_i32 + i * ldc + j + idx1 * VLEN_INT32),
                mask_int32_v);
          }

          __m256i one_16bit_v = _mm256_set1_epi16(1);
          for (int idx = 0; idx < 4; ++idx) {
            __m256i c_i16_v = _mm256_maddubs_epi16(br_v[idx], a_v);
            __m256i c_i32_v = _mm256_madd_epi16(one_16bit_v, c_i16_v);
            c_v[idx] = _mm256_add_epi32(c_v[idx], c_i32_v);
          }

          int idx2 = 0;
          for (; idx2 < rem / VLEN_INT32; ++idx2) {
            _mm256_storeu_si256(
                reinterpret_cast<__m256i*>(
                    C_i32 + i * ldc + j + idx2 * VLEN_INT32),
                c_v[idx2]);
          }
          if (rem_int32 > 0) {
            _mm256_maskstore_epi32(
                reinterpret_cast<int*>(C_i32 + i * ldc + j + idx2 * VLEN_INT32),
                mask_int32_v,
                c_v[idx2]);
          }
        }
      }
    }
  }

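  // Requantize the full int32 accumulator block to uint8 in one pass.
  // The template flags pick the variant: the activation is treated as
  // symmetric when act_zero_point == 0, the weights are always treated as
  // symmetric here, and HAS_BIAS follows whether rParams.bias is set.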
  block_type_t block{0, M, 0, N};
  if (rParams.bias == nullptr) {
    if (rParams.act_zero_point) {
      trRequantizeOpt<
          FUSE_RELU,
          /*ACT_SYMMETRIC*/ false,
          /*WEIGHT_SYMMETRIC*/ true,
          /*HAS_BIAS*/ false,
          Q_GRAN>(C_u8, C_i32, block, ldc, ldc, rParams);
    } else {
      trRequantizeOpt<
          FUSE_RELU,
          /*ACT_SYMMETRIC*/ true,
          /*WEIGHT_SYMMETRIC*/ true,
          /*HAS_BIAS*/ false,
          Q_GRAN>(C_u8, C_i32, block, ldc, ldc, rParams);
    }
  } else {
    if (rParams.act_zero_point) {
      trRequantizeOpt<
          FUSE_RELU,
          /*ACT_SYMMETRIC*/ false,
          /*WEIGHT_SYMMETRIC*/ true,
          /*HAS_BIAS*/ true,
          Q_GRAN>(C_u8, C_i32, block, ldc, ldc, rParams);
    } else {
      trRequantizeOpt<
          FUSE_RELU,
          /*ACT_SYMMETRIC*/ true,
          /*WEIGHT_SYMMETRIC*/ true,
          /*HAS_BIAS*/ true,
          Q_GRAN>(C_u8, C_i32, block, ldc, ldc, rParams);
    }
  }
}