in src/FbgemmSparseDenseInt8Avx512.cc [314:558]
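// Sparse (int8, BCSR) x dense (uint8) matrix multiply with fused
// requantization: computes accum ? C += A * B : C = A * B, accumulating in
// int32 scratch (C_i32, reused per 64-column strip) and writing requantized
// uint8 results to C_u8 after the last k-tile.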
void SparseDenseInt8MMAvx512(
int N,
const std::unique_ptr<BCSRMatrix<>>& bcsr,
const uint8_t* B,
int ldb,
int32_t* C_i32,
uint8_t* C_u8,
int ldc,
trRequantizationParams_t& rParams,
bool accum,
int thread_id,
int num_threads) {
// Fast path: N == 1 with unit strides is a matrix-vector product; the MV
// kernel requires K (bcsr->C) to be a multiple of 4.
if (N == 1 && ldb == 1 && ldc == 1 && bcsr->C % 4 == 0) {
return SparseDenseInt8MVAvx512<FUSE_RELU, Q_GRAN>(
bcsr, B, ldb, C_i32, C_u8, rParams, accum, thread_id, num_threads);
}
// Calculates accum ? C += A * B : C = A * B
constexpr int VLEN_INT8 = 64; // uint8 lanes per zmm register
constexpr int VLEN_INT32 = 16; // int32 lanes per zmm register
// The reduction (K) dimension is processed in tiles of COLTILE columns of A.
constexpr int colTileSize = BCSRMatrix<>::COLTILE;
// Number of columns in the sparse matrix A
int K = bcsr->C;
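// Number of rows in the sparse matrix A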
int M = bcsr->R;
assert((K > 0) && "K needs to be positive");
int kTiles = (K + colTileSize - 1) / colTileSize;
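// BCSR index arrays; rowBPtr is laid out per k-tile, so row i of tile kt
// starts at row_ptr[kt * M + i].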
const int* row_ptr = bcsr->rowBPtr.data();
const int* col_idx = bcsr->colBIdx.data();
const int8_t* values = bcsr->values.data();
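// Per-thread scratch holding one interleaved k-tile of B. Allocated once per
// thread and never freed; it persists for the thread's lifetime.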
constexpr int buffer_size = BCSRMatrix<>::COLTILE * VLEN_INT8;
static thread_local uint8_t* interleave_buffer_ = nullptr;
if (interleave_buffer_ == nullptr) {
interleave_buffer_ =
static_cast<uint8_t*>(fbgemmAlignedAlloc(64, buffer_size));
}
assert(
(interleave_buffer_ != nullptr) &&
"interleave_buffer_ cannot be nullptr");
__m512i one_16bit_v = _mm512_set1_epi16(1);
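// Main loop over full 64-column (VLEN_INT8) strips of B and C.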
int j = 0;
for (; j < N / VLEN_INT8 * VLEN_INT8; j += VLEN_INT8) {
for (int kt = 0; kt < kTiles; ++kt) {
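// The last k-tile may cover fewer than colTileSize columns of A.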
int curKSize = std::min(K - kt * colTileSize, colTileSize);
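// Pack this k-tile's 64-column strip of B into the 4-row interleaved
// layout the reduction kernel consumes.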
interleave4RowsTile<4 /*COLBLOCKS*/>(
N, curKSize, B + kt * colTileSize * ldb, interleave_buffer_, ldb, j);
for (int i = 0; i < M; ++i) {
__m512i c_v[4];
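// Resume from the int32 partial sums when accumulating into C or past the
// first k-tile. Note C_i32 is indexed with row stride ldb and carries no j
// offset, so the scratch is reused for each 64-column strip.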
if (accum || kt > 0) {
for (int idx = 0; idx < 4; ++idx) {
c_v[idx] = _mm512_loadu_si512(C_i32 + i * ldb + idx * VLEN_INT32);
}
} else {
for (int idx = 0; idx < 4; ++idx) {
c_v[idx] = _mm512_setzero_si512();
}
}
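// Walk the nonzeros of row i within this k-tile, multiply-accumulating
// against the interleaved B strip; UNROLL presumably controls how many
// nonzero blocks are processed per iteration.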
loopOverReductionDim<2 /*UNROLL*/, 4 /*COLBLOCKS*/>(
row_ptr + kt * M,
i,
col_idx,
values,
interleave_buffer_,
one_16bit_v,
c_v);
if (kt == kTiles - 1) {
// Requantize after the last k-tile
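// The second and third template flags appear to mean "activation zero
// point is zero" and "has bias", matching the branches below.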
__m512i res;
if (rParams.bias == nullptr) {
if (rParams.act_zero_point) {
res = requantizeForMM<FUSE_RELU, false, false, Q_GRAN>(
c_v, i, rParams);
} else {
res = requantizeForMM<FUSE_RELU, true, false, Q_GRAN>(
c_v, i, rParams);
}
} else {
if (rParams.act_zero_point) {
res = requantizeForMM<FUSE_RELU, false, true, Q_GRAN>(
c_v, i, rParams);
} else {
res = requantizeForMM<FUSE_RELU, true, true, Q_GRAN>(
c_v, i, rParams);
}
}
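// Store 64 requantized uint8 outputs for row i, columns [j, j + 64).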
_mm512_storeu_si512(C_u8 + i * ldc + j, res);
} else {
// Not the last k-tile: spill int32 partial sums back to C_i32.
for (int idx = 0; idx < 4; ++idx) {
_mm512_storeu_si512(C_i32 + i * ldb + idx * VLEN_INT32, c_v[idx]);
}
}
}
}
}
// Handle the column remainder (fewer than VLEN_INT8 columns left)
int rem_int8 = N - j; // remaining uint8 output columns (1..63 when nonzero)
int rem_int32 = N % VLEN_INT32; // columns in the last partial int32 block
int colBlocks = (rem_int8 + VLEN_INT32 - 1) / VLEN_INT32; // 16-wide blocks covering the tail
if (rem_int8 > 0) {
for (int kt = 0; kt < kTiles; ++kt) {
// As above, the last k-tile may cover fewer than colTileSize columns of A
// (i.e. rows of B).
int curKSize = std::min(K - kt * colTileSize, colTileSize);
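// Dispatch on the runtime block count to a compile-time COLBLOCKS
// instantiation.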
switch (colBlocks) {
case 1:
interleave4RowsTile<1>(
N,
curKSize,
B + kt * colTileSize * ldb,
interleave_buffer_,
ldb,
j);
break;
case 2:
interleave4RowsTile<2>(
N,
curKSize,
B + kt * colTileSize * ldb,
interleave_buffer_,
ldb,
j);
break;
case 3:
interleave4RowsTile<3>(
N,
curKSize,
B + kt * colTileSize * ldb,
interleave_buffer_,
ldb,
j);
break;
case 4:
interleave4RowsTile<4>(
N,
curKSize,
B + kt * colTileSize * ldb,
interleave_buffer_,
ldb,
j);
break;
default:
// not reachable
break;
}
// Tail masks. Shift an unsigned 1 so that rem_int8 == 63 does not shift
// into the sign bit of a signed long long.
__mmask16 mask_int32_v = (static_cast<uint64_t>(1) << rem_int32) - 1;
__mmask64 mask_int8_v = (static_cast<uint64_t>(1) << rem_int8) - 1;
for (int i = 0; i < M; ++i) {
__m512i c_v[4] = {};
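// As in the main loop, but the last partial int32 block needs a masked load.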
if (accum || kt > 0) {
int idx = 0;
for (; idx < rem_int8 / VLEN_INT32; ++idx) {
c_v[idx] = _mm512_loadu_si512(C_i32 + i * ldb + idx * VLEN_INT32);
}
c_v[idx] = _mm512_maskz_loadu_epi32(
mask_int32_v, C_i32 + i * ldb + idx * VLEN_INT32);
}
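// Same sparse reduction as the main loop; tails with fewer column blocks
// use a deeper UNROLL, presumably spending the freed registers on extra
// unrolling.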
switch (colBlocks) {
case 1:
loopOverReductionDim<3 /*UNROLL*/, 1 /*colBlocks*/>(
row_ptr + M * kt,
i,
col_idx,
values,
interleave_buffer_,
one_16bit_v,
c_v);
break;
case 2:
loopOverReductionDim<3 /*UNROLL*/, 2 /*colBlocks*/>(
row_ptr + M * kt,
i,
col_idx,
values,
interleave_buffer_,
one_16bit_v,
c_v);
break;
case 3:
loopOverReductionDim<2 /*UNROLL*/, 3 /*colBlocks*/>(
row_ptr + M * kt,
i,
col_idx,
values,
interleave_buffer_,
one_16bit_v,
c_v);
break;
case 4:
loopOverReductionDim<2 /*UNROLL*/, 4 /*colBlocks*/>(
row_ptr + M * kt,
i,
col_idx,
values,
interleave_buffer_,
one_16bit_v,
c_v);
break;
default:
// not reachable
break;
}
if (kt == kTiles - 1) {
// Requantize after the last k-tile
__m512i res;
if (rParams.bias == nullptr) {
if (rParams.act_zero_point) {
res = requantizeForMM<FUSE_RELU, false, false, Q_GRAN>(
c_v, i, rParams);
} else {
res = requantizeForMM<FUSE_RELU, true, false, Q_GRAN>(
c_v, i, rParams);
}
} else {
if (rParams.act_zero_point) {
res = requantizeForMM<FUSE_RELU, false, true, Q_GRAN>(
c_v, i, rParams);
} else {
res = requantizeForMM<FUSE_RELU, true, true, Q_GRAN>(
c_v, i, rParams);
}
}
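// Masked store writes only the rem_int8 valid output columns.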
_mm512_mask_storeu_epi8(C_u8 + i * ldc + j, mask_int8_v, res);
} else {
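// Not the last k-tile: spill partial sums, masking the tail block.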
int idx = 0;
for (; idx < rem_int8 / VLEN_INT32; ++idx) {
_mm512_storeu_si512(C_i32 + i * ldb + idx * VLEN_INT32, c_v[idx]);
}
_mm512_mask_storeu_epi32(
C_i32 + i * ldb + idx * VLEN_INT32, mask_int32_v, c_v[idx]);
}
}
}
}
}