in src/FbgemmI64.cc [403:540]
/**
 * @brief int64 x int64 -> int64 matrix multiply: C = op(A) * op(B), or
 *        C += op(A) * op(B) when @p accumulate is true.
 *
 * Falls back to cblas_gemm_i64_i64acc_ref when AVX-512 is not available.
 * Otherwise the product is computed tile by tile (MCB x NCB x KCB) with a
 * JIT-generated micro-kernel operating on packed copies of A and B.
 *
 * @param transa     If matrix_op_t::Transpose, element (i, k) of op(A) is read
 *                   from A[i + k * lda]; otherwise from A[i * lda + k].
 * @param transb     If matrix_op_t::Transpose, element (k, j) of op(B) is read
 *                   from B[k + j * ldb]; otherwise from B[k * ldb + j].
 * @param M          Number of rows of op(A) and C.
 * @param N          Number of columns of op(B) and C.
 * @param K          Number of columns of op(A) / rows of op(B).
 * @param A          Input matrix A.
 * @param lda        Leading dimension (stride between consecutive rows as
 *                   stored) of A.
 * @param B          Input matrix B.
 * @param ldb        Leading dimension of B.
 * @param accumulate When true the product is added to the existing contents
 *                   of C; when false C is overwritten.
 * @param C          Output matrix; element (i, j) lives at C[i * ldc + j].
 * @param ldc        Leading dimension of C.
 */
void cblas_gemm_i64_i64acc(
    matrix_op_t transa,
    matrix_op_t transb,
    int M,
    int N,
    int K,
    const int64_t* A,
    int lda,
    const int64_t* B,
    int ldb,
    bool accumulate,
    int64_t* C,
    int ldc) {
  cpuinfo_initialize();
  if (!fbgemmHasAvx512Support()) {
    cblas_gemm_i64_i64acc_ref(
        transa, transb, M, N, K, A, lda, B, ldb, accumulate, C, ldc);
    return;
  }

  // Empty output: nothing to compute or store.
  if (M <= 0 || N <= 0) {
    return;
  }
  // Empty reduction dimension: the blocked loops below would never run, which
  // would leave C untouched even though accumulate == false asks for C to be
  // overwritten. The product of an M x 0 and a 0 x N matrix is all zeros.
  if (K <= 0) {
    if (!accumulate) {
      for (int i = 0; i < M; ++i) {
        memset(
            C + static_cast<int64_t>(i) * ldc,
            0,
            static_cast<std::size_t>(N) * sizeof(int64_t));
      }
    }
    return;
  }

  constexpr int MCB = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::MCB;
  constexpr int NCB = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::NCB;
  constexpr int KCB = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::KCB;
  constexpr int MR = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::MR;
  constexpr int NR = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::NR;
  static_assert(MCB % MR == 0, "MR must divide MCB");
  static_assert(NCB % NR == 0, "NR must divide NCB");
  constexpr int VLEN =
      simd_info<inst_set_t::avx512>::WIDTH_BYTES / sizeof(int64_t);
  static_assert(NR % VLEN == 0, "VLEN must divide NR");

  using CodeGenType = CodeGenBase<int64_t, int64_t, int64_t, int64_t>;
  CodeGenType codeObj;
  // Accumulating kernel: used for every K block except possibly the first.
  CodeGenType::jit_micro_kernel_fp fn =
      codeObj.getOrCreate<inst_set_t::avx512>(true /* accum */, MCB, NCB, KCB);
  // Overwriting kernel: only needed (and only generated) when the caller does
  // not want to accumulate. Initialised to fn so it never holds an
  // indeterminate value.
  CodeGenType::jit_micro_kernel_fp fn_noacc = fn;
  if (!accumulate) {
    fn_noacc = codeObj.getOrCreate<inst_set_t::avx512>(
        false /* accum */, MCB, NCB, KCB);
  }

  std::vector<int64_t> At, Bt;
  // TODO: handle transpose during packing
  // Offsets are computed in 64-bit so that large M/N/K/ld* values cannot
  // overflow int arithmetic.
  if (transa == matrix_op_t::Transpose) {
    At.resize(static_cast<std::size_t>(M) * K);
    for (int i = 0; i < M; ++i) {
      for (int k = 0; k < K; ++k) {
        At[static_cast<std::size_t>(i) * K + k] =
            A[i + static_cast<int64_t>(k) * lda];
      }
    }
    A = At.data();
    lda = K;
  }
  if (transb == matrix_op_t::Transpose) {
    Bt.resize(static_cast<std::size_t>(K) * N);
    for (int k = 0; k < K; ++k) {
      for (int j = 0; j < N; ++j) {
        Bt[static_cast<std::size_t>(k) * N + j] =
            B[k + static_cast<int64_t>(j) * ldb];
      }
    }
    B = Bt.data();
    ldb = N;
  }

  // Scratch tiles. Only the valid part of each tile is (re)filled per
  // iteration; entries outside it may hold stale data, but the results
  // computed from them land in packC positions that are never copied to C.
  alignas(64) std::array<int64_t, MCB * KCB> packA;
  alignas(64) std::array<int64_t, KCB * NCB> packB;
  alignas(64) std::array<int64_t, MCB * NCB> packC;

  for (int ic = 0; ic < M; ic += MCB) {
    const int mb = std::min(MCB, M - ic); // valid rows in this tile
    for (int kc = 0; kc < K; kc += KCB) {
      const int kb = std::min(KCB, K - kc); // valid depth in this tile
      // The first K block overwrites C unless the caller asked to accumulate;
      // every later K block adds onto the partial sums already stored.
      const bool overwrite = (kc == 0) && !accumulate;
      CodeGenType::jit_micro_kernel_fp kernel = overwrite ? fn_noacc : fn;

      // pack A
      for (int i = 0; i < mb; ++i) {
        memcpy(
            &packA[i * KCB],
            A + static_cast<int64_t>(ic + i) * lda + kc,
            kb * sizeof(int64_t));
      }

      for (int jc = 0; jc < N; jc += NCB) {
        const int nb = std::min(NCB, N - jc); // valid columns in this tile

        // pack B
        for (int i = 0; i < kb; ++i) {
          memcpy(
              &packB[i * NCB],
              B + static_cast<int64_t>(kc + i) * ldb + jc,
              nb * sizeof(int64_t));
        }

        int64_t* const c_tile = C + static_cast<int64_t>(ic) * ldc + jc;

        if (mb == MCB && nb == NCB) {
          // Full tile: the kernel reads/writes C in place with row stride ldc.
          // The kernel takes two B pointers; both receive packB here.
          kernel(
              packA.data(), packB.data(), packB.data(), c_tile, kb, ldc);
        } else {
          // remainder: run the full-size kernel on the packC scratch tile
          // (row stride NCB) and copy only the valid mb x nb part back.
          if (!overwrite) {
            // Accumulating: seed the scratch tile with the current C values.
            for (int i = 0; i < mb; ++i) {
              memcpy(
                  &packC[i * NCB],
                  c_tile + static_cast<int64_t>(i) * ldc,
                  nb * sizeof(int64_t));
            }
          }
          kernel(
              packA.data(),
              packB.data(),
              packB.data(),
              packC.data(),
              kb,
              NCB);
          for (int i = 0; i < mb; ++i) {
            memcpy(
                c_tile + static_cast<int64_t>(i) * ldc,
                &packC[i * NCB],
                nb * sizeof(int64_t));
          }
        }
      } // jc
    } // kc
  } // ic
}