void SparseDenseInt8MMAvx512()

in src/FbgemmSparseDenseInt8Avx512.cc [314:558]


void SparseDenseInt8MMAvx512(
    int N,
    const std::unique_ptr<BCSRMatrix<>>& bcsr,
    const uint8_t* B,
    int ldb,
    int32_t* C_i32,
    uint8_t* C_u8,
    int ldc,
    trRequantizationParams_t& rParams,
    bool accum,
    int thread_id,
    int num_threads) {
  // Sparse (BCSR) x dense int8 matrix multiply with fused requantization.
  // Calculates accum ? C += A * B : C = A * B, where A is the M x K sparse
  // matrix held in `bcsr` (M = bcsr->R, K = bcsr->C) and B is dense with N
  // columns.
  //
  // Parameters:
  //   N          number of columns of B and of the output.
  //   bcsr       sparse left operand A.
  //   B, ldb     dense uint8 right operand; consecutive rows of B are ldb
  //              elements apart.
  //   C_i32      int32 accumulator laid out like C_u8 (row i starts at
  //              C_i32 + i * ldc). It is read when `accum` is true and is used
  //              to carry partial sums between column tiles of A. The int32
  //              sums of the final column tile are not written back.
  //   C_u8, ldc  requantized output; row i starts at C_u8 + i * ldc.
  //   rParams    requantization parameters (bias / act_zero_point select the
  //              requantizeForMM specialization).
  //   accum      when true, values already in C_i32 are added to A * B.
  //   thread_id, num_threads  only forwarded to the gemv fast path.

  // gemv fast path
  if (N == 1 && ldb == 1 && ldc == 1 && bcsr->C % 4 == 0) {
    return SparseDenseInt8MVAvx512<FUSE_RELU, Q_GRAN>(
        bcsr, B, ldb, C_i32, C_u8, rParams, accum, thread_id, num_threads);
  }

  constexpr int VLEN_INT8 = 64;
  constexpr int VLEN_INT32 = 16;

  constexpr int colTileSize = BCSRMatrix<>::COLTILE;
  // Number of columns in the sparse matrix A
  int K = bcsr->C;
  int M = bcsr->R;
  assert((K > 0) && "K needs to be positive");
  int kTiles = (K + colTileSize - 1) / colTileSize;
  const int* row_ptr = bcsr->rowBPtr.data();
  const int* col_idx = bcsr->colBIdx.data();
  const int8_t* values = bcsr->values.data();

  // Scratch space for one interleaved k tile of B. It is allocated on the
  // first call made by each thread and reused afterwards.
  // NOTE(review): this function never frees it, so each thread that calls it
  // keeps buffer_size bytes for the thread's lifetime.
  constexpr int buffer_size = BCSRMatrix<>::COLTILE * VLEN_INT8;
  static thread_local uint8_t* interleave_buffer_ = nullptr;

  if (interleave_buffer_ == nullptr) {
    interleave_buffer_ =
        static_cast<uint8_t*>(fbgemmAlignedAlloc(64, buffer_size));
  }

  assert(
      (interleave_buffer_ != nullptr) &&
      "interleave_buffer_ cannot be nullptr");

  // Requantizes the four int32 accumulator vectors of output row `row` into
  // the __m512i that is stored to C_u8. The requantizeForMM specialization
  // is chosen from whether a bias is present and whether act_zero_point is 0.
  auto requantizeRow = [&rParams](__m512i(&acc)[4], int row) -> __m512i {
    if (rParams.bias == nullptr) {
      if (rParams.act_zero_point) {
        return requantizeForMM<FUSE_RELU, false, false, Q_GRAN>(
            acc, row, rParams);
      }
      return requantizeForMM<FUSE_RELU, true, false, Q_GRAN>(
          acc, row, rParams);
    }
    if (rParams.act_zero_point) {
      return requantizeForMM<FUSE_RELU, false, true, Q_GRAN>(
          acc, row, rParams);
    }
    return requantizeForMM<FUSE_RELU, true, true, Q_GRAN>(acc, row, rParams);
  };

  __m512i one_16bit_v = _mm512_set1_epi16(1);
  int j = 0;
  // Main loop: blocks of 64 output columns.
  for (; j < N / VLEN_INT8 * VLEN_INT8; j += VLEN_INT8) {
    for (int kt = 0; kt < kTiles; ++kt) {
      int curKSize = std::min(K - kt * colTileSize, colTileSize);
      interleave4RowsTile<4 /*COLBLOCKS*/>(
          N, curKSize, B + kt * colTileSize * ldb, interleave_buffer_, ldb, j);
      for (int i = 0; i < M; ++i) {
        // int32 accumulators for row i, columns [j, j + 64). Indexed with the
        // output's leading dimension and column offset so that `accum` adds
        // onto the matching columns.
        int32_t* c_i32_row = C_i32 + i * ldc + j;
        __m512i c_v[4];
        if (accum || kt > 0) {
          for (int idx = 0; idx < 4; ++idx) {
            c_v[idx] = _mm512_loadu_si512(c_i32_row + idx * VLEN_INT32);
          }
        } else {
          for (int idx = 0; idx < 4; ++idx) {
            c_v[idx] = _mm512_set1_epi32(0);
          }
        }

        loopOverReductionDim<2 /*UNROLL*/, 4 /*COLBLOCKS*/>(
            row_ptr + kt * M,
            i,
            col_idx,
            values,
            interleave_buffer_,
            one_16bit_v,
            c_v);

        if (kt == kTiles - 1) {
          // Requantize after the last k tile.
          _mm512_storeu_si512(C_u8 + i * ldc + j, requantizeRow(c_v, i));
        } else {
          // Carry the partial sums over to the next k tile.
          for (int idx = 0; idx < 4; ++idx) {
            _mm512_storeu_si512(c_i32_row + idx * VLEN_INT32, c_v[idx]);
          }
        }
      }
    }
  }

  // Remainder: the last N % 64 columns, handled with masked loads/stores.
  int rem_int8 = N - j;
  if (rem_int8 > 0) {
    // Number of complete 16-lane int32 vectors, and lanes left after them.
    int full_int32 = rem_int8 / VLEN_INT32;
    int rem_int32 = rem_int8 % VLEN_INT32;
    int colBlocks = (rem_int8 + VLEN_INT32 - 1) / VLEN_INT32;
    // Lane masks are loop-invariant; rem_int8 < 64, so the shifts are valid.
    __mmask16 mask_int32_v =
        static_cast<__mmask16>((1ULL << rem_int32) - 1);
    __mmask64 mask_int8_v = static_cast<__mmask64>((1ULL << rem_int8) - 1);

    for (int kt = 0; kt < kTiles; ++kt) {
      // last k tile may have less than colTileSize columns of A matrix (aka
      // rows of B)
      int curKSize = std::min(K - kt * colTileSize, colTileSize);
      const uint8_t* B_tile = B + kt * colTileSize * ldb;
      switch (colBlocks) {
        case 1:
          interleave4RowsTile<1>(
              N, curKSize, B_tile, interleave_buffer_, ldb, j);
          break;
        case 2:
          interleave4RowsTile<2>(
              N, curKSize, B_tile, interleave_buffer_, ldb, j);
          break;
        case 3:
          interleave4RowsTile<3>(
              N, curKSize, B_tile, interleave_buffer_, ldb, j);
          break;
        case 4:
          interleave4RowsTile<4>(
              N, curKSize, B_tile, interleave_buffer_, ldb, j);
          break;
        default:
          // not reachable: 0 < rem_int8 < 64
          break;
      }

      for (int i = 0; i < M; ++i) {
        int32_t* c_i32_row = C_i32 + i * ldc + j;
        __m512i c_v[4] = {};
        if (accum || kt > 0) {
          int idx = 0;
          for (; idx < full_int32; ++idx) {
            c_v[idx] = _mm512_loadu_si512(c_i32_row + idx * VLEN_INT32);
          }
          c_v[idx] = _mm512_maskz_loadu_epi32(
              mask_int32_v, c_i32_row + idx * VLEN_INT32);
        }

        switch (colBlocks) {
          case 1:
            loopOverReductionDim<3 /*UNROLL*/, 1 /*colBlocks*/>(
                row_ptr + kt * M,
                i,
                col_idx,
                values,
                interleave_buffer_,
                one_16bit_v,
                c_v);
            break;
          case 2:
            loopOverReductionDim<3 /*UNROLL*/, 2 /*colBlocks*/>(
                row_ptr + kt * M,
                i,
                col_idx,
                values,
                interleave_buffer_,
                one_16bit_v,
                c_v);
            break;
          case 3:
            loopOverReductionDim<2 /*UNROLL*/, 3 /*colBlocks*/>(
                row_ptr + kt * M,
                i,
                col_idx,
                values,
                interleave_buffer_,
                one_16bit_v,
                c_v);
            break;
          case 4:
            loopOverReductionDim<2 /*UNROLL*/, 4 /*colBlocks*/>(
                row_ptr + kt * M,
                i,
                col_idx,
                values,
                interleave_buffer_,
                one_16bit_v,
                c_v);
            break;
          default:
            // not reachable: 0 < rem_int8 < 64
            break;
        }

        if (kt == kTiles - 1) {
          // Requantize after the last k tile; only the first rem_int8 bytes
          // are written.
          _mm512_mask_storeu_epi8(
              C_u8 + i * ldc + j, mask_int8_v, requantizeRow(c_v, i));
        } else {
          // Carry the partial sums over to the next k tile.
          int idx = 0;
          for (; idx < full_int32; ++idx) {
            _mm512_storeu_si512(c_i32_row + idx * VLEN_INT32, c_v[idx]);
          }
          _mm512_mask_storeu_epi32(
              c_i32_row + idx * VLEN_INT32, mask_int32_v, c_v[idx]);
        }
      }
    }
  }
}