/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#define FBGEMM_EXPORTS
#include "fbgemm/FbgemmI64.h"
#include <immintrin.h>
#include <cmath>
#include <iostream>
#include <vector>
#include "./GenerateKernel.h"
#include "./RefImplementations.h"
#include "fbgemm/PackingTraits-inl.h"
using namespace std;
namespace fbgemm {
/**
 * Generate AVX512 instructions for computing a block in the rank-k update of
 * the 64-bit accumulation kernel.
 */
template <>
template <inst_set_t instSet>
void CodeGenBase<int64_t, int64_t, int64_t, int64_t>::genComputeBlock(
x86::Emitter* a,
x86::Gp buffer_A,
x86::Gp buffer_B,
x86::Gp B_pf,
int rowRegs,
int colRegs,
int lda) {
using VecRegT = typename simd_info<instSet>::vec_reg_t;
constexpr int vectorLen = simd_info<instSet>::WIDTH_BITS / 64;
// used for matrix B
VecRegT BReg(31);
// temporary register
VecRegT res1(30);
for (int j = 0; j < colRegs; ++j) {
// load B
a->vmovaps(
BReg,
x86::Mem(
buffer_B,
j * vectorLen * sizeof(int64_t),
simd_info<instSet>::WIDTH_BYTES));
// load A, broadcast and fmas
for (int i = 0; i < rowRegs; ++i) {
a->vpmullq(
res1,
BReg,
x86::qword_ptr(buffer_A, (i * lda) * sizeof(int64_t))._1to8());
a->vpaddq(VecRegT(i * colRegs + j), res1, VecRegT(i * colRegs + j));
}
// TODO: need to tune
a->prefetcht0(x86::dword_ptr(B_pf, j * vectorLen * sizeof(int64_t)));
}
}
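// For orientation: one call to genComputeBlock emits a single k-step of the
// rank-k update. A minimal scalar sketch of what the emitted instructions
// compute (Ctile/Arow/Brow are illustrative names, not JIT symbols):
//
//   for (int i = 0; i < rowRegs; ++i)
//     for (int j = 0; j < colRegs * vectorLen; ++j)
//       Ctile[i][j] += Arow[i * lda] * Brow[j];  // vpmullq + vpaddq
//
// BReg (register 31) holds vectorLen consecutive int64 elements of B, the
// broadcast of A[i * lda] is fused into vpmullq via _1to8(), and res1
// (register 30) holds the product before accumulation.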
/**
 * Generate AVX512 instructions for storing the C registers back to memory in
 * the 64-bit accumulation kernel.
 */
template <>
template <inst_set_t instSet>
void CodeGenBase<int64_t, int64_t, int64_t, int64_t>::storeCRegs(
x86::Emitter* a,
int rowRegs,
int colRegs,
x86::Gp C_Offset,
x86::Gp ldcReg,
bool accum) {
using VecT = typename simd_info<instSet>::vec_reg_t;
static constexpr int vectorLen = simd_info<instSet>::WIDTH_BITS / 64;
for (int i = 0; i < rowRegs; ++i) {
if (i != 0) {
a->add(C_Offset, ldcReg);
} else {
a->xor_(C_Offset.r32(), C_Offset.r32());
}
for (int j = 0; j < colRegs; ++j) {
if (accum) {
a->vpaddq(
VecT(i * colRegs + j),
VecT(i * colRegs + j),
x86::dword_ptr(
a->zcx(), C_Offset, 0, j * vectorLen * sizeof(int64_t)));
}
a->vmovups(
x86::dword_ptr(
a->zcx(), C_Offset, 0, j * vectorLen * sizeof(int64_t)),
VecT(i * colRegs + j));
}
}
}
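// Equivalent scalar view of storeCRegs, as an illustrative sketch (Ctile is
// the accumulator tile; C is the destination pointed to by zcx/CBase):
//
//   for (int i = 0; i < rowRegs; ++i)
//     for (int j = 0; j < colRegs * vectorLen; ++j)
//       C[i * ldc + j] = (accum ? C[i * ldc + j] : 0) + Ctile[i][j];
//
// C_Offset tracks i * ldc in bytes (ldcReg is pre-scaled by sizeof(int64_t)
// in getOrCreate), which is why row 0 starts by zeroing it.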
/**
 * Get or create the AVX512 instructions for the int64_t GEMM macro-kernel.
 */
template <>
template <inst_set_t instSet>
CodeGenBase<int64_t, int64_t, int64_t, int64_t>::jit_micro_kernel_fp
CodeGenBase<int64_t, int64_t, int64_t, int64_t>::getOrCreate(
bool accum,
int32_t mc,
int32_t nc,
int32_t /* unused */) {
static constexpr int vectorLen = simd_info<instSet>::WIDTH_BITS / 64;
tuple<bool, int, int, int, int, int, int> kernelSig;
int kBlock;
int nBlock;
int mRegBlockSize;
int nRegBlockSize;
if (blocking_params) {
kBlock = blocking_params->KCB;
nBlock = blocking_params->NCB;
mRegBlockSize = blocking_params->MR;
nRegBlockSize = blocking_params->NR;
} else {
kBlock = PackingTraits<int64_t, int64_t, instSet>::KCB;
nBlock = PackingTraits<int64_t, int64_t, instSet>::NCB;
mRegBlockSize = PackingTraits<int64_t, int64_t, instSet>::MR;
nRegBlockSize = PackingTraits<int64_t, int64_t, instSet>::NR;
}
kernelSig =
make_tuple(accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize);
return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
asmjit::CodeHolder code;
code.init(runtime().environment());
x86::Assembler assembler(&code);
x86::Emitter* a = assembler.as<x86::Emitter>();
#ifdef FBGEMM_LOG_CODE
// generated code logging
FILE* codeLogfile = fopen(
getCodeLoggingFile<instSet>(
accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize)
.c_str(),
"w");
asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
if (codeLogger) {
code.setLogger(codeLogger);
}
#endif
const int maxMRegs = mRegBlockSize;
(void)maxMRegs; // Suppress unused variable warning
const int maxNRegs = nRegBlockSize / vectorLen;
assert(
maxMRegs * maxNRegs <= 30 &&
"MR*(NR*64/512) must be <= 30 (available registers constraint)");
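// Register budget behind the assert: genComputeBlock reserves vector
// registers 31 (BReg) and 30 (res1), leaving 30 accumulators for the C tile.
// With 512-bit vectors, vectorLen = 512 / 64 = 8 int64 lanes, so the
// constraint reads MR * (NR / 8) <= 30.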
const int mRegBlocks = mc / mRegBlockSize;
const int mRegBlocksRem = mc % mRegBlockSize;
// arguments to the function created
x86::Gp buffer_A = a->zdi();
x86::Gp buffer_B = a->zsi();
x86::Gp B_pf = a->zdx();
x86::Gp CBase = a->zcx();
x86::Gp kSize = a->gpz(8);
x86::Gp ldcReg = a->gpz(9);
asmjit::FuncDetail func;
func.init(
asmjit::FuncSignatureT<
void,
int64_t*,
int64_t*,
int64_t*,
int64_t*,
int,
int>(asmjit::CallConv::kIdHost),
a->environment());
asmjit::FuncFrame frame;
frame.init(func);
frame.setDirtyRegs(
x86::Reg::kGroupVec,
asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31));
frame.setDirtyRegs(
x86::Reg::kGroupGp,
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
asmjit::FuncArgsAssignment args(&func);
args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
args.updateFuncFrame(frame);
frame.finalize();
a->emitProlog(frame);
a->emitArgsAssignment(frame, args);
asmjit::Label LoopMBlocks = a->newLabel();
asmjit::Label LoopNBlocks = a->newLabel();
asmjit::Label Loopk = a->newLabel();
x86::Gp buffer_B_saved = a->gpz(10);
x86::Gp C_Offset = a->gpz(11);
x86::Gp B_pf_saved = a->gpz(12);
x86::Gp iIdx = a->gpz(13);
x86::Gp jIdx = a->gpz(14);
x86::Gp kIdx = a->gpz(15);
a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int64_t)));
a->imul(kSize, kSize, static_cast<asmjit::Imm>(sizeof(int64_t)));
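// From here on, ldcReg and kSize are byte counts rather than element counts,
// so the k loop can advance kIdx by sizeof(int64_t) and compare it against
// kSize directly, and storeCRegs can add ldcReg straight into the byte
// offset C_Offset.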
// save B_buffer address
a->mov(buffer_B_saved, buffer_B);
a->mov(B_pf_saved, B_pf);
int currColRegs = nc / vectorLen;
int colRegs = std::min(currColRegs, maxNRegs);
if (mRegBlocks > 0) {
// move 0 to iteration variables
a->xor_(iIdx.r32(), iIdx.r32());
a->bind(LoopMBlocks);
a->inc(iIdx);
a->xor_(jIdx.r32(), jIdx.r32());
a->bind(LoopNBlocks);
a->inc(jIdx);
int rowRegs = mRegBlockSize;
// init C registers
initCRegs(a, rowRegs, colRegs);
// init k loop index
a->xor_(kIdx.r32(), kIdx.r32());
a->bind(Loopk);
// advance the k index by one element (kIdx and kSize count bytes)
a->add(kIdx, static_cast<asmjit::Imm>(sizeof(int64_t)));
genComputeBlock<instSet>(
a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
// update buffer_A address for next k iteration
a->add(buffer_A, static_cast<asmjit::Imm>(sizeof(int64_t)));
// update buffer_B address for next k iteration
a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * sizeof(int64_t)));
a->add(B_pf, static_cast<asmjit::Imm>(nBlock * sizeof(int64_t)));
a->cmp(kIdx, kSize);
a->jl(Loopk);
// store C matrix
storeCRegs<instSet>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
// reset A
a->sub(buffer_A, kSize);
// B for next block
a->mov(buffer_B, buffer_B_saved);
// using C_Offset as temp reg
a->imul(
C_Offset,
jIdx,
static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int64_t)));
a->add(buffer_B, C_Offset);
a->mov(B_pf, B_pf_saved);
a->add(B_pf, C_Offset);
// increment C for next B block
a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int64_t)));
int jLoopTrips = currColRegs / maxNRegs;
// jLoopTrips should be at least 1
jLoopTrips = jLoopTrips ? jLoopTrips : 1;
a->cmp(jIdx, jLoopTrips);
a->jl(LoopNBlocks);
// increment A for next block
a->add(
buffer_A,
static_cast<asmjit::Imm>(rowRegs * kBlock * sizeof(int64_t)));
// increment C for next A block
a->sub(
CBase,
static_cast<asmjit::Imm>(
jLoopTrips * nRegBlockSize * sizeof(int64_t)));
a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
a->add(CBase, C_Offset);
// reset B
a->mov(buffer_B, buffer_B_saved);
a->mov(B_pf, B_pf_saved);
a->cmp(iIdx, mRegBlocks);
a->jl(LoopMBlocks);
}
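// The code emitted above corresponds to this loop nest (illustrative
// pseudocode, not emitted symbols):
//
//   for (iIdx = 1; iIdx <= mRegBlocks; ++iIdx)     // MR-row stripes of A/C
//     for (jIdx = 1; jIdx <= jLoopTrips; ++jIdx)   // NR-column stripes of B/C
//       { init C tile; for (k = 0; k < kSize; ++k) computeBlock; store C; }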
// generate code for remainder
if (mRegBlocksRem > 0) {
asmjit::Label LoopNRem = a->newLabel();
asmjit::Label LoopkRem = a->newLabel();
int rowRegs = mRegBlocksRem;
a->xor_(jIdx.r32(), jIdx.r32());
a->bind(LoopNRem);
a->inc(jIdx);
// init C registers
initCRegs(a, rowRegs, colRegs);
// init k loop index
a->xor_(kIdx.r32(), kIdx.r32());
a->bind(LoopkRem);
// advance the k index by one element (kIdx and kSize count bytes)
a->add(kIdx, static_cast<asmjit::Imm>(sizeof(int64_t)));
genComputeBlock<instSet>(
a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
// update buffer_A address for next k iteration
a->add(buffer_A, static_cast<asmjit::Imm>(sizeof(int64_t)));
// update buffer_B address for next k iteration
a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * sizeof(int64_t)));
a->add(B_pf, static_cast<asmjit::Imm>(nBlock * sizeof(int64_t)));
a->cmp(kIdx, kSize);
a->jl(LoopkRem);
// reset A
a->sub(buffer_A, kSize);
// B for next block
// using C_Offset as temp reg
a->imul(
C_Offset,
jIdx,
static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int64_t)));
a->mov(buffer_B, buffer_B_saved);
a->add(buffer_B, C_Offset);
a->mov(B_pf, B_pf_saved);
a->add(B_pf, C_Offset);
// store C matrix
storeCRegs<instSet>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
// increment C for next B block
a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int64_t)));
int jLoopTrips = currColRegs / maxNRegs;
// jLoopTrips should be at least 1
jLoopTrips = jLoopTrips ? jLoopTrips : 1;
a->cmp(jIdx, jLoopTrips);
a->jl(LoopNRem);
}
a->emitEpilog(frame);
jit_micro_kernel_fp fn;
asmjit::Error err;
{
unique_lock<mutex> lock(rtMutex_);
err = runtime().add(&fn, &code);
}
if (err) {
cout << "Error: in fn add" << endl;
return nullptr;
}
#ifdef FBGEMM_LOG_CODE
fclose(codeLogfile);
delete codeLogger;
#endif
return fn;
});
}
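// The returned jit_micro_kernel_fp follows the FuncSignatureT above:
//
//   fn(packedA, packedB, B_prefetch, C, kSize, ldc);
//
// A hedged usage sketch mirroring the call sites in cblas_gemm_i64_i64acc
// below (packA/packB stand for MCB*KCB and KCB*NCB packed buffers):
//
//   CodeGenBase<int64_t, int64_t, int64_t, int64_t> codeObj;
//   auto fn = codeObj.getOrCreate<inst_set_t::avx512>(
//       /*accum=*/true, MCB, NCB, KCB);
//   fn(packA.data(), packB.data(), packB.data(), C, kSize, ldc);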
/**
 * Instantiate the AVX512 instructions for the int64_t GEMM macro-kernel.
 */
template CodeGenBase<int64_t, int64_t, int64_t, int64_t>::jit_micro_kernel_fp
CodeGenBase<int64_t, int64_t, int64_t, int64_t>::getOrCreate<
inst_set_t::avx512>(bool accum, int32_t mc, int32_t nc, int32_t kc);
// Multiplication of int64 values is expected to overflow, so UBSan is
// suppressed for this function.
NO_SANITIZE("undefined")
void cblas_gemm_i64_i64acc(
matrix_op_t transa,
matrix_op_t transb,
int M,
int N,
int K,
const int64_t* A,
int lda,
const int64_t* B,
int ldb,
bool accumulate,
int64_t* C,
int ldc) {
cpuinfo_initialize();
if (!fbgemmHasAvx512Support()) {
cblas_gemm_i64_i64acc_ref(
transa, transb, M, N, K, A, lda, B, ldb, accumulate, C, ldc);
return;
}
constexpr int MCB = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::MCB;
constexpr int NCB = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::NCB;
constexpr int KCB = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::KCB;
constexpr int MR = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::MR;
constexpr int NR = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::NR;
static_assert(MCB % MR == 0, "MR must divide MCB");
static_assert(NCB % NR == 0, "NR must divide NCB");
constexpr int VLEN =
simd_info<inst_set_t::avx512>::WIDTH_BYTES / sizeof(int64_t);
static_assert(NR % VLEN == 0, "VLEN must divide NR");
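// With AVX512, WIDTH_BYTES is 64, so VLEN = 64 / sizeof(int64_t) = 8; the
// static_assert above guarantees that every NR-wide stripe of B maps to a
// whole number of vector registers.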
using CodeGenType = CodeGenBase<int64_t, int64_t, int64_t, int64_t>;
CodeGenType codeObj;
CodeGenType::jit_micro_kernel_fp fn =
codeObj.getOrCreate<inst_set_t::avx512>(true /* accum */, MCB, NCB, KCB);
CodeGenType::jit_micro_kernel_fp fn_noacc;
if (!accumulate) {
fn_noacc = codeObj.getOrCreate<inst_set_t::avx512>(
false /* accum */, MCB, NCB, KCB);
}
vector<int64_t> At, Bt;
// TODO: handle transpose during packing
if (transa == matrix_op_t::Transpose) {
At.resize(M * K);
for (int i = 0; i < M; ++i) {
for (int k = 0; k < K; ++k) {
At.at(i * K + k) = A[i + k * lda];
}
}
A = At.data();
lda = K;
}
if (transb == matrix_op_t::Transpose) {
Bt.resize(K * N);
for (int k = 0; k < K; ++k) {
for (int j = 0; j < N; ++j) {
Bt.at(k * N + j) = B[k + j * ldb];
}
}
B = Bt.data();
ldb = N;
}
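// Past this point both A and B are row-major: transposed inputs have been
// materialized into temporaries (At, Bt) at the cost of O(M*K) or O(K*N)
// extra memory, with lda/ldb rebased to the new leading dimensions.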
alignas(64) array<int64_t, MCB * KCB> packA;
alignas(64) array<int64_t, KCB * NCB> packB;
alignas(64) array<int64_t, MCB * NCB> packC;
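// Cache-blocking scheme of the loop nest below (illustrative pseudocode):
//
//   for (ic = 0; ic < M; ic += MCB)       // row blocks of C
//     for (kc = 0; kc < K; kc += KCB)     // reduction blocks
//       pack A[ic:ic+MCB, kc:kc+KCB] into packA
//       for (jc = 0; jc < N; jc += NCB)   // column blocks of C
//         pack B[kc:kc+KCB, jc:jc+NCB] into packB
//         full MCB x NCB tiles run the JIT kernel directly on C; edge
//         tiles are staged through packC and copied back row by row.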
for (int ic = 0; ic < M; ic += MCB) {
for (int kc = 0; kc < K; kc += KCB) {
// pack A
for (int i = 0; i < std::min(MCB, M - ic); ++i) {
memcpy(
&packA[i * KCB],
A + (ic + i) * lda + kc,
std::min(K - kc, KCB) * sizeof(int64_t));
}
for (int jc = 0; jc < N; jc += NCB) {
// pack B
for (int i = 0; i < std::min(KCB, K - kc); ++i) {
memcpy(
&packB[i * NCB],
B + (kc + i) * ldb + jc,
std::min(NCB, N - jc) * sizeof(int64_t));
}
if (M - ic >= MCB && N - jc >= NCB) {
if (kc == 0 && !accumulate) {
fn_noacc(
packA.data(),
packB.data(),
packB.data(),
C + ic * ldc + jc,
std::min(KCB, K - kc),
ldc);
} else {
fn(packA.data(),
packB.data(),
packB.data(),
C + ic * ldc + jc,
std::min(KCB, K - kc),
ldc);
}
} else {
// remainder
if (kc == 0 && !accumulate) {
fn_noacc(
packA.data(),
packB.data(),
packB.data(),
packC.data(),
std::min(KCB, K - kc),
NCB);
} else {
for (int i = 0; i < std::min(MCB, M - ic); ++i) {
memcpy(
&packC[i * NCB],
C + (ic + i) * ldc + jc,
std::min(NCB, N - jc) * sizeof(int64_t));
}
fn(packA.data(),
packB.data(),
packB.data(),
packC.data(),
std::min(KCB, K - kc),
NCB);
}
for (int i = 0; i < std::min(MCB, M - ic); ++i) {
memcpy(
C + (ic + i) * ldc + jc,
&packC[i * NCB],
std::min(NCB, N - jc) * sizeof(int64_t));
}
}
} // jc
} // kc
} // ic
}
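// Minimal caller-side sketch for cblas_gemm_i64_i64acc, assuming row-major
// M x K and K x N inputs (mirrors the signature above; adds no new API):
//
//   std::vector<int64_t> A(M * K), B(K * N), C(M * N, 0);
//   fbgemm::cblas_gemm_i64_i64acc(
//       fbgemm::matrix_op_t::NoTranspose, fbgemm::matrix_op_t::NoTranspose,
//       M, N, K, A.data(), K, B.data(), N,
//       /*accumulate=*/false, C.data(), N);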
} // namespace fbgemm