in src/sim/sim_driver.cc [399:452]
void RunGEMM(const VTAGemInsn* op) {
if (!op->reset_reg) {
prof_->gemm_counter += op->iter_out * op->iter_in * (op->uop_end - op->uop_bgn);
if (prof_->SkipExec()) return;
for (uint32_t y = 0; y < op->iter_out; ++y) {
for (uint32_t x = 0; x < op->iter_in; ++x) {
for (uint32_t uindex = op->uop_bgn; uindex < op->uop_end; ++uindex) {
VTAUop* uop_ptr = static_cast<VTAUop*>(uop_.BeginPtr(uindex));
// Read in memory indices
uint32_t acc_idx = uop_ptr->dst_idx;
uint32_t inp_idx = uop_ptr->src_idx;
uint32_t wgt_idx = uop_ptr->wgt_idx;
acc_idx += y * op->dst_factor_out + x * op->dst_factor_in;
inp_idx += y * op->src_factor_out + x * op->src_factor_in;
wgt_idx += y * op->wgt_factor_out + x * op->wgt_factor_in;
BitPacker<VTA_ACC_WIDTH> acc(acc_.BeginPtr(acc_idx));
BitPacker<VTA_INP_WIDTH> inp(inp_.BeginPtr(inp_idx));
BitPacker<VTA_WGT_WIDTH> wgt(wgt_.BeginPtr(wgt_idx));
// gemm loop
for (uint32_t i = 0; i < VTA_BATCH; ++i) {
for (uint32_t j = 0; j < VTA_BLOCK_OUT; ++j) {
uint32_t acc_offset = i * VTA_BLOCK_OUT + j;
int32_t sum = acc.GetSigned(acc_offset);
for (uint32_t k = 0; k < VTA_BLOCK_IN; ++k) {
sum +=
inp.GetSigned(i * VTA_BLOCK_IN + k) *
wgt.GetSigned(j * VTA_BLOCK_IN + k);
}
acc.SetSigned(acc_offset, sum);
}
}
}
}
}
} else {
if (prof_->SkipExec()) return;
// reset
for (uint32_t y = 0; y < op->iter_out; ++y) {
for (uint32_t x = 0; x < op->iter_in; ++x) {
for (uint32_t uindex = op->uop_bgn; uindex < op->uop_end; ++uindex) {
VTAUop* uop_ptr = static_cast<VTAUop*>(uop_.BeginPtr(uindex));
uint32_t acc_idx = uop_ptr->dst_idx;
acc_idx += y * op->dst_factor_out + x * op->dst_factor_in;
BitPacker<VTA_ACC_WIDTH> acc(acc_.BeginPtr(acc_idx));
for (uint32_t i = 0; i < VTA_BATCH * VTA_BLOCK_OUT; ++i) {
acc.SetSigned(i, 0);
}
}
}
}
}
}