in src/GenerateI8Depthwise.cc [52:174]
// Emits AVX2 code computing, per channel j of a 32-byte vector,
//   c[j] (+)= sum_{k < n} A_k[j] * B_k[j]
// for up to 4 filter positions k. A bytes (in a[k]) are treated as unsigned
// and B bytes (loaded from memory at [b]) as signed, per vpmaddubsw
// semantics; the intermediate int16 sums saturate (vpmaddubsw).
// B is loaded directly with no shuffling, so it is presumably stored
// pre-interleaved to match the A interleave below — confirm against the
// B packing code.
//
// Parameters:
//   e            Emitter the instructions are appended to.
//   a            Input registers, one per filter position; a[k] for k >= n
//                is ignored (replaced by `zero`). ALL FOUR are clobbered and
//                reused as scratch.
//   b            GP register holding the address of the weights. The n > 2
//                path reads [b], [b+32] and, if remainder >= 16, [b+64],
//                [b+96]; the n <= 2 path reads [b] and [b+32].
//   c            Four int32 result registers. Channel layout depends on n:
//                  n > 2:  c[0]: j=0..3 | 16..19   c[1]: j=4..7   | 20..23
//                          c[2]: j=8..11| 24..27   c[3]: j=12..15 | 28..31
//                  n <= 2: c[0]: j=0..7   c[1]: j=8..15
//                          c[2]: j=16..23 c[3]: j=24..31
//                (low 128 bits | high 128 bits; derived from the per-lane
//                behaviour of vpunpck* and vpmovsxwd on the low xmm half).
//                NOTE(review): the two layouts differ; presumably the caller
//                compensates (e.g. via how A is loaded for the tail) when
//                accumulating across calls with different n — confirm.
//   a_sum        Optional (nullptr = skip). Otherwise points to 2 registers
//                that receive int16 row sums of A with saturating adds:
//                a_sum[0] covers j=0..7 | 16..23, a_sum[1] j=8..15 | 24..31.
//                a_sum[1] is only touched when remainder >= 8.
//   n            Number of valid filter positions; n == 1 pairs a[0] with
//                `zero`, n == 3 pairs a[2] with `zero`. Expected in [1, 4].
//   remainder    Gates the upper halves of the computation (>= 8 for the
//                high byte-unpacks / a_sum[1], >= 16 for c[2] and c[3]).
//                Presumably the count of valid channels in this vector —
//                confirm against the caller.
//   accumulation true: add into c / a_sum; false: overwrite them.
//   one_epi8     Expected to hold 1 in every byte (used to sum A bytes).
//   one_epi16    Expected to hold 1 in every int16 (used to widen int16
//                pairs to int32 sums via vpmaddwd).
//   zero         Expected to be all zeros (pads missing filter positions).
//
// Also clobbers ymm0 and ymm1 (constructed by id below).
static void genMaddEpi16xNPacked(
x86::Emitter* e,
x86::Ymm a[4],
x86::Gp b,
x86::Ymm c[4],
x86::Ymm* a_sum,
int n,
int remainder,
bool accumulation,
x86::Ymm one_epi8,
x86::Ymm one_epi16,
x86::Ymm zero) {
// Interleave inputs corresponding to 4 filter positions.
// Reuse a[1] and a[3] to save registers
// a01_* are fixed temporaries ymm0/ymm1. a23_lo/a23_hi alias a[1]/a[3];
// this is safe because a[1] is fully consumed by the a01 unpacks before
// a23_lo is written, and a[3] is read in the same instruction that
// overwrites it.
x86::Ymm a01_lo(0), a01_hi(1), a23_lo(a[1]), a23_hi(a[3]);
// Byte-interleave positions 0 and 1 (per 128-bit lane): lo gets
// channels 0..7 | 16..23, hi gets channels 8..15 | 24..31.
e->vpunpcklbw(a01_lo, a[0], n == 1 ? zero : a[1]);
if (remainder >= 8) {
e->vpunpckhbw(a01_hi, a[0], n == 1 ? zero : a[1]);
}
// Same for positions 2 and 3 (written into a[1]/a[3]).
if (n > 2) {
e->vpunpcklbw(a23_lo, a[2], n == 3 ? zero : a[3]);
if (remainder >= 8) {
e->vpunpckhbw(a23_hi, a[2], n == 3 ? zero : a[3]);
}
}
// Compute row_wise sum of A for row_offsets
// vpmaddubsw with all-ones bytes adds each interleaved (A_k, A_k+1) byte
// pair into one int16; a[0] and a[2] are free here and serve as scratch.
if (a_sum) {
if (accumulation) {
e->vpmaddubsw(a[0], a01_lo, one_epi8);
e->vpaddsw(a_sum[0], a[0], a_sum[0]);
if (remainder >= 8) {
e->vpmaddubsw(a[2], a01_hi, one_epi8);
e->vpaddsw(a_sum[1], a[2], a_sum[1]);
}
} else {
e->vpmaddubsw(a_sum[0], a01_lo, one_epi8);
if (remainder >= 8) {
e->vpmaddubsw(a_sum[1], a01_hi, one_epi8);
}
}
// Positions 2 and 3 are always added on top of what was just written.
if (n > 2) {
e->vpmaddubsw(a[0], a23_lo, one_epi8);
e->vpaddsw(a_sum[0], a[0], a_sum[0]);
if (remainder >= 8) {
e->vpmaddubsw(a[2], a23_hi, one_epi8);
e->vpaddsw(a_sum[1], a[2], a_sum[1]);
}
}
}
if (n > 2) {
// Reusing a
// Word-interleave so each 4-byte group holds (A0,A1,A2,A3)[j] for one
// channel j. a[1]/a[3] are overwritten by the instruction that reads
// them as a23_lo/a23_hi.
e->vpunpcklwd(a[0], a01_lo, a23_lo);
e->vpunpckhwd(a[1], a01_lo, a23_lo);
// NOTE(review): a[2]/a[3] (and hence c[2]/c[3]) hold channels 8..15 |
// 24..31 in this layout, but are gated on remainder >= 16 while a01_hi
// is produced for remainder >= 8 — confirm the caller's remainder
// semantics make this correct.
if (remainder >= 16) {
e->vpunpcklwd(a[2], a01_hi, a23_hi);
e->vpunpckhwd(a[3], a01_hi, a23_hi);
}
// u8 x s8 products summed in pairs -> int16 (saturating).
e->vpmaddubsw(a[0], a[0], x86::ymmword_ptr(b));
e->vpmaddubsw(a[1], a[1], x86::ymmword_ptr(b, 32));
if (remainder >= 16) {
e->vpmaddubsw(a[2], a[2], x86::ymmword_ptr(b, 64));
e->vpmaddubsw(a[3], a[3], x86::ymmword_ptr(b, 96));
}
// Multiply by 1 and add adjacent int16 pairs -> one int32 per channel.
if (accumulation) {
e->vpmaddwd(a[0], a[0], one_epi16);
e->vpaddd(c[0], c[0], a[0]);
e->vpmaddwd(a[1], a[1], one_epi16);
e->vpaddd(c[1], c[1], a[1]);
if (remainder >= 16) {
e->vpmaddwd(a[2], a[2], one_epi16);
e->vpaddd(c[2], c[2], a[2]);
e->vpmaddwd(a[3], a[3], one_epi16);
e->vpaddd(c[3], c[3], a[3]);
}
} else {
e->vpmaddwd(c[0], a[0], one_epi16);
e->vpmaddwd(c[1], a[1], one_epi16);
if (remainder >= 16) {
e->vpmaddwd(c[2], a[2], one_epi16);
e->vpmaddwd(c[3], a[3], one_epi16);
}
}
} else {
// Reusing a
// With at most 2 positions, one vpmaddubsw already yields one int16 per
// channel; only sign-extension to int32 remains.
// NOTE(review): a01_hi (ymm1) is read below even when remainder < 8, in
// which case it was not written by this function; the resulting c[1]
// lanes are presumably never stored — confirm.
e->vpmaddubsw(a[0], a01_lo, x86::ymmword_ptr(b));
e->vpmaddubsw(a[1], a01_hi, x86::ymmword_ptr(b, 32));
if (accumulation) {
// Low xmm halves -> channels 0..7 (c[0]) and 8..15 (c[1]).
e->vpmovsxwd(a[2], a[0].half());
e->vpaddd(c[0], c[0], a[2]);
e->vpmovsxwd(a[3], a[1].half());
e->vpaddd(c[1], c[1], a[3]);
if (remainder >= 16) {
// Move each high 128-bit lane down, then widen: channels 16..23
// (c[2]) and 24..31 (c[3]).
e->vextracti128(a[0].half(), a[0], asmjit::Imm(1));
e->vpmovsxwd(a[0], a[0].half());
e->vpaddd(c[2], c[2], a[0]);
e->vextracti128(a[1].half(), a[1], asmjit::Imm(1));
e->vpmovsxwd(a[1], a[1].half());
e->vpaddd(c[3], c[3], a[1]);
}
} else {
e->vpmovsxwd(c[0], a[0].half());
e->vpmovsxwd(c[1], a[1].half());
if (remainder >= 16) {
e->vextracti128(a[0].half(), a[0], asmjit::Imm(1));
e->vpmovsxwd(c[2], a[0].half());
e->vextracti128(a[1].half(), a[1], asmjit::Imm(1));
e->vpmovsxwd(c[3], a[1].half());
}
}
}
}