in src/Fbgemm.cc [27:206]
void fbgemmPacked(
    PackMatrix<
        packingAMatrix,
        typename packingAMatrix::inpType,
        typename packingAMatrix::accType>& packA,
    PackMatrix<
        packingBMatrix,
        typename packingBMatrix::inpType,
        typename packingBMatrix::accType>& packB,
    cT* C,
    int32_t* C_buffer,
    uint32_t ldc,
    const processOutputType& outProcess,
    int thread_id,
    int num_threads,
    const BlockingFactors* blocking_params) {
  // Packed GEMM driver for this thread's share of the work.
  //
  // For every group g assigned to this thread, walks the thread's rows of A
  // in blocks of MCB rows and the group's K dimension in blocks of KCB. For
  // each (row block, K block) it packs that block of A via packA.pack() and
  // then calls ExecuteKernel::execute() with block index g * kBlocks + kb.
  //
  // packA / packB: packed operands. B must already be prepacked, and both must
  //   report the same (positive) number of groups.
  // C, C_buffer, ldc, outProcess: forwarded unchanged to ExecuteKernel.
  //   Presumably these are the output matrix, an int32 accumulation scratch
  //   buffer, C's leading dimension and the output pipeline; confirm against
  //   ExecuteKernel.
  // thread_id / num_threads: used by fbgemmGetThreadPartition to split the
  //   groups and rows across threads.
  // blocking_params: optional MCB/KCB/MR override. When null, the defaults
  //   come from PackingTraits for the instruction set detected at run time.
  //
  // Throws std::runtime_error in these cases:
  //   - cpuinfo fails to initialize;
  //   - the CPU supports none of AVX2/AVX-512/AVX-512 VNNI;
  //   - a blocking factor is non-positive;
  //   - B is not prepacked;
  //   - the group counts are mismatched or non-positive.
  static_assert(
      std::is_same<
          typename packingAMatrix::accType,
          typename packingBMatrix::accType>::value,
      "Accumulation type of both matrices should be the same");

  // Run time CPU detection
  if (!cpuinfo_initialize()) {
    throw std::runtime_error("Failed to initialize cpuinfo!");
  }
  if (!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
      !fbgemmHasAvx2Support()) {
    assert(0 && "unknown architecure");
    throw std::runtime_error("unknown architecure");
  }

  // Cache-blocking factors. Use the caller's override when given, otherwise
  // the per-ISA defaults.
  int64_t MCB;
  int KCB;
  int MR;
  if (blocking_params) {
    MCB = blocking_params->MCB;
    KCB = blocking_params->KCB;
    MR = blocking_params->MR;
  } else {
    const inst_set_t isa = fbgemmInstructionSet();
    switch (isa) {
      case inst_set_t::avx512_vnni:
        std::tie(MCB, KCB, MR) = PackingTraits<
            typename packingAMatrix::inpType,
            typename packingAMatrix::accType,
            inst_set_t::avx512_vnni>::getCacheBlockParams();
        break;
      case inst_set_t::avx512_vnni_ymm:
        std::tie(MCB, KCB, MR) = PackingTraits<
            typename packingAMatrix::inpType,
            typename packingAMatrix::accType,
            inst_set_t::avx512_vnni_ymm>::getCacheBlockParams();
        break;
      case inst_set_t::avx512:
        std::tie(MCB, KCB, MR) = PackingTraits<
            typename packingAMatrix::inpType,
            typename packingAMatrix::accType,
            inst_set_t::avx512>::getCacheBlockParams();
        break;
      case inst_set_t::avx512_ymm:
        std::tie(MCB, KCB, MR) = PackingTraits<
            typename packingAMatrix::inpType,
            typename packingAMatrix::accType,
            inst_set_t::avx512_ymm>::getCacheBlockParams();
        break;
      case inst_set_t::avx2:
        std::tie(MCB, KCB, MR) = PackingTraits<
            typename packingAMatrix::inpType,
            typename packingAMatrix::accType,
            inst_set_t::avx2>::getCacheBlockParams();
        break;
      default:
        assert(0 && "unknown architecure");
        throw std::runtime_error("unknown architecure");
    }
  }

  // Reject non-positive blocking factors up front, because the code below
  // misbehaves on them:
  //   - KCB is a divisor when computing kBlocks, so zero is a division by zero;
  //   - MCB is the row-loop stride, so a value <= 0 never terminates;
  //   - MR is passed as the row-blocking granularity for partitioning.
  if (MCB <= 0 || KCB <= 0 || MR <= 0) {
    throw std::runtime_error(
        "Invalid blocking parameters: MCB = " + std::to_string(MCB) +
        ", KCB = " + std::to_string(KCB) + ", MR = " + std::to_string(MR) +
        " (all must be positive)");
  }

  if (!packB.isPrePacked()) {
    throw std::runtime_error("B matrix must be prepacked");
  }
  const int G = packA.numGroups();
  if (G != packB.numGroups()) {
    throw std::runtime_error(
        "A.groups = " + std::to_string(G) + " and B.groups = " +
        std::to_string(packB.numGroups()) + " are not the same");
  }
  // G is a divisor just below; zero groups would be a division by zero.
  if (G <= 0) {
    throw std::runtime_error(
        "Number of groups must be positive, got " + std::to_string(G));
  }

  const int MDim = packA.numRows();
  const int KDimPerGroup = packB.numRows() / G;
  const int NDim = packB.numCols();
  // Number of K blocks per group. The last block is partial when KCB does not
  // divide KDimPerGroup.
  const int kBlocks = (KDimPerGroup + KCB - 1) / KCB;
  // Width of that trailing partial K block (0 when there is none).
  const int kcRemainder = KDimPerGroup % KCB;

#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
  std::chrono::time_point<std::chrono::high_resolution_clock> t_very_start,
      t_start, t_end;
  double dt;
  t_start = std::chrono::high_resolution_clock::now();
  t_very_start = std::chrono::high_resolution_clock::now();
#endif

  thread_type_t th_info =
      fbgemmGetThreadPartition(G, MDim, NDim, thread_id, num_threads);

  int64_t g_begin, g_end, i_begin, i_end;
  // Calculate the begin and end index along the group dimension
  fbgemmPartition1D(
      th_info.g_thread_id, th_info.g_num_threads, G, g_begin, g_end);
  // Calculate the begin and end index along the m dimension (MR-aligned)
  fbgemmPartition1DBlocked(
      th_info.m_thread_id, th_info.m_num_threads, MDim, MR, i_begin, i_end);

  for (int g = g_begin; g < g_end; ++g) {
    // The kernel executor is constructed once per group, as in the original
    // code. It is kept per-group rather than hoisted in case its construction
    // resets per-group state (TODO confirm in ExecuteKernel).
    ExecuteKernel<packingAMatrix, packingBMatrix, cT, processOutputType>
        exeKernelObj(
            packA,
            packB,
            C,
            C_buffer,
            ldc,
            outProcess,
            th_info,
            blocking_params);
    for (int i = i_begin; i < i_end; i += MCB) { // i is the element index
      // Rows in this block: MCB, or fewer for the final block.
      const int mc = static_cast<int>(std::min(i_end - i, MCB));
      for (int kb = 0; kb < kBlocks; ++kb) { // kb is the block index
        const int kc =
            (kb != kBlocks - 1 || kcRemainder == 0) ? KCB : kcRemainder;
        // Pack the (mc x kc) block of A. Its K offset includes the group's
        // offset into the concatenated K dimension.
        block_type_t blockA{i, mc, g * KDimPerGroup + kb * KCB, kc};
        packA.pack(blockA);
#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
        t_end = std::chrono::high_resolution_clock::now();
        dt = std::chrono::duration_cast<std::chrono::nanoseconds>(
                 t_end - t_start)
                 .count();
        packing_time += (dt);
        t_start = std::chrono::high_resolution_clock::now();
#endif
        exeKernelObj.execute(g * kBlocks + kb);
#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
        t_end = std::chrono::high_resolution_clock::now();
        dt = std::chrono::duration_cast<std::chrono::nanoseconds>(
                 t_end - t_start)
                 .count();
        computing_time += (dt);
        t_start = std::chrono::high_resolution_clock::now();
#endif
      }
    }
  } // for each group

#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
  t_end = std::chrono::high_resolution_clock::now();
  dt =
      std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_very_start)
          .count();
  run_time += (dt);
#endif
}