in gemmology.h [1301:1353]
void Engine<Arch>::Shift::Multiply(const uint8_t *A, const int8_t *B,
                                   size_t A_rows, size_t width, size_t B_cols,
                                   Callback callback, ExecutionEngine& engine) {
  // Multiplies the unsigned 8-bit matrix A (A_rows x width, row-major) by the
  // prepared signed 8-bit matrix B, eight columns of B at a time, and hands
  // each row's eight int32 results to `callback`.
  //
  // NOTE(review): preconditions inferred from the indexing below, not checked
  // here — TODO confirm against callers:
  //   * width is a non-zero multiple of batch8::size. The first inner step is
  //     peeled and reads A_vec[0] and B_group[0..7] unconditionally. Bytes
  //     beyond simd_width * batch8::size are never read.
  //   * B_cols is a multiple of 8, matching the engine step.
  //   * B is laid out so that each group of 8 columns occupies
  //     simd_width * 8 consecutive registers: register (step * 8 + j) of a
  //     group holds the step-th chunk of column j.
  using batch8 = xsimd::batch<int8_t, Arch>;
  using ubatch8 = xsimd::batch<uint8_t, Arch>;
  using batch32 = xsimd::batch<int32_t, Arch>;

  // Number of full SIMD registers along the shared (inner) dimension. It is
  // identical for every task, so compute it once and capture it by value.
  const size_t simd_width = width / batch8::size;

  // Presumably `engine` invokes the task with group_col = 0, 8, 16, ... <
  // B_cols, possibly concurrently. `callback` is captured by reference, so it
  // must outlive the engine call — TODO confirm the engine's threading model.
  engine(0, B_cols, 8,
         [A, B, A_rows, width, B_cols, simd_width, &callback](size_t B0_colidx) {
    // First register of the 8-column group starting at column B0_colidx.
    const batch8 *const B_group =
        reinterpret_cast<const batch8 *>(B) + simd_width * B0_colidx;

    // One row of A per iteration. Per the original authors, handling several
    // rows of A at once did not appear to be faster.
    for (size_t row = 0; row < A_rows; ++row) {
      const ubatch8 *const A_vec =
          reinterpret_cast<const ubatch8 *>(A + row * width);

      // Eight int32 accumulators, one per column in the group. The first
      // inner step is peeled so that the accumulators are initialised straight
      // from maddw, without a separate zero-fill.
      const ubatch8 a_first = A_vec[0];
      batch32 acc0 = maddw(a_first, B_group[0]);
      batch32 acc1 = maddw(a_first, B_group[1]);
      batch32 acc2 = maddw(a_first, B_group[2]);
      batch32 acc3 = maddw(a_first, B_group[3]);
      batch32 acc4 = maddw(a_first, B_group[4]);
      batch32 acc5 = maddw(a_first, B_group[5]);
      batch32 acc6 = maddw(a_first, B_group[6]);
      batch32 acc7 = maddw(a_first, B_group[7]);

      // Remaining inner steps: multiply u8 x s8 and accumulate into the int32
      // lanes. maddw's exact lane grouping is defined elsewhere.
      for (size_t step = 1; step < simd_width; ++step) {
        const ubatch8 a_step = A_vec[step];
        const batch8 *const B_step = B_group + step * 8;
        acc0 = maddw(a_step, B_step[0], acc0);
        acc1 = maddw(a_step, B_step[1], acc1);
        acc2 = maddw(a_step, B_step[2], acc2);
        acc3 = maddw(a_step, B_step[3], acc3);
        acc4 = maddw(a_step, B_step[4], acc4);
        acc5 = maddw(a_step, B_step[5], acc5);
        acc6 = maddw(a_step, B_step[6], acc6);
        acc7 = maddw(a_step, B_step[7], acc7);
      }

      // Horizontal reduction: Pack0123 reduces four accumulators within
      // 128-bit lanes, and PermuteSummer finishes the per-architecture
      // reduction of both halves.
      auto sums_lo = Pack0123(acc0, acc1, acc2, acc3);
      auto sums_hi = Pack0123(acc4, acc5, acc6, acc7);
      auto total = PermuteSummer(sums_lo, sums_hi);
      callback(total, row, B0_colidx, B_cols);
    }
  });
}