static void genMaddEpi16xNPacked()

in src/GenerateI8Depthwise.cc [52:174]


// Emits, through the emitter `e`, a sequence of AVX2 instructions that
// multiplies up to four registers of 8-bit inputs (`a`) with operands read
// from memory at `b`, and writes or accumulates the 32-bit results in `c`.
// Optionally also produces 16-bit pairwise sums of the inputs in `a_sum`.
//
// This function only generates code; nothing is computed when it is called.
//
// Parameters:
//   e            - emitter that receives the generated instructions.
//   a            - input registers. a[1] is replaced by `zero` when n == 1 and
//                  a[3] is replaced by `zero` when n == 3. The entries of `a`
//                  are reused as scratch/outputs of intermediate steps, so
//                  they do not hold the inputs after the emitted sequence.
//   b            - register holding the base address of the memory operand;
//                  read as 32-byte operands at offsets 0 and 32 (and at 64 and
//                  96 when n > 2 and remainder >= 16).
//   c            - registers receiving the 32-bit results; c[0] and c[1] are
//                  always written, c[2] and c[3] only when remainder >= 16.
//   a_sum        - optional (may be null). When non-null, a_sum[0] (and
//                  a_sum[1] when remainder >= 8) receive 16-bit sums of the
//                  interleaved input bytes.
//   n            - number of valid entries in `a`.
//                  NOTE(review): the code assumes 1 <= n <= 4 — confirm.
//   remainder    - controls how much of the upper part of the data is
//                  processed (thresholds 8 and 16 below).
//                  NOTE(review): presumably the number of remaining elements
//                  to process — confirm against the caller.
//   accumulation - true: add to the current contents of `c` / `a_sum`;
//                  false: overwrite them.
//   one_epi8     - presumably a register with every byte equal to 1 (from the
//                  name) — confirm against the caller.
//   one_epi16    - presumably a register with every 16-bit element equal to 1
//                  (from the name) — confirm against the caller.
//   zero         - presumably an all-zero register (from the name); used as
//                  the interleave partner when n is odd.
//
// Registers with ids 0 and 1 are used as temporaries (a01_lo / a01_hi) and are
// overwritten. NOTE(review): this presumably requires that the caller keeps
// those two registers free and distinct from a[], c[] and the constant
// registers — verify.
static void genMaddEpi16xNPacked(
    x86::Emitter* e,
    x86::Ymm a[4],
    x86::Gp b,
    x86::Ymm c[4],
    x86::Ymm* a_sum,
    int n,
    int remainder,
    bool accumulation,
    x86::Ymm one_epi8,
    x86::Ymm one_epi16,
    x86::Ymm zero) {
  // Interleave inputs corresponding to 4 filter positions.
  // a01_* are registers with hard-coded ids 0 and 1. a23_* are copies of the
  // a[1] and a[3] operands, i.e. they name the same registers (reused to save
  // registers).
  x86::Ymm a01_lo(0), a01_hi(1), a23_lo(a[1]), a23_hi(a[3]);
  // Byte-interleave a[0] with a[1] (or with zero if there is no second input).
  // The unpack instructions work within each 128-bit lane: "l" takes the low
  // 8 bytes of each lane, "h" the high 8 bytes.
  e->vpunpcklbw(a01_lo, a[0], n == 1 ? zero : a[1]);
  if (remainder >= 8) {
    e->vpunpckhbw(a01_hi, a[0], n == 1 ? zero : a[1]);
  }
  if (n > 2) {
    // Order matters: a23_lo overwrites a[1], which is safe only because a[1]
    // was already consumed above; likewise a23_hi overwrites a[3] after the
    // preceding instruction has read it.
    e->vpunpcklbw(a23_lo, a[2], n == 3 ? zero : a[3]);
    if (remainder >= 8) {
      e->vpunpckhbw(a23_hi, a[2], n == 3 ? zero : a[3]);
    }
  }

  // Compute row_wise sum of A for row_offsets.
  // vpmaddubsw multiplies the unsigned bytes of its first source with the
  // signed bytes of its second source and adds adjacent products into 16-bit
  // elements; with one_epi8 as the multiplier this yields pairwise sums of the
  // interleaved input bytes.
  if (a_sum) {
    if (accumulation) {
      // a[0] and a[2] are used as temporaries here; the running sums are
      // updated with a saturating signed 16-bit add (vpaddsw).
      e->vpmaddubsw(a[0], a01_lo, one_epi8);
      e->vpaddsw(a_sum[0], a[0], a_sum[0]);

      if (remainder >= 8) {
        e->vpmaddubsw(a[2], a01_hi, one_epi8);
        e->vpaddsw(a_sum[1], a[2], a_sum[1]);
      }
    } else {
      // No previous value to keep: write the pairwise sums directly.
      e->vpmaddubsw(a_sum[0], a01_lo, one_epi8);
      if (remainder >= 8) {
        e->vpmaddubsw(a_sum[1], a01_hi, one_epi8);
      }
    }

    if (n > 2) {
      // Add the contribution of the second input pair (a[2], a[3]). This is
      // always an add, regardless of `accumulation`, because a_sum already
      // holds the a[0]/a[1] contribution at this point.
      e->vpmaddubsw(a[0], a23_lo, one_epi8);
      e->vpaddsw(a_sum[0], a[0], a_sum[0]);

      if (remainder >= 8) {
        e->vpmaddubsw(a[2], a23_hi, one_epi8);
        e->vpaddsw(a_sum[1], a[2], a_sum[1]);
      }
    }
  }

  if (n > 2) {
    // Three or four inputs: interleave the 16-bit (a0,a1) pairs with the
    // 16-bit (a2,a3) pairs so that every 32-bit element holds one byte from
    // each of the four inputs. The results are stored back into a[] (reusing
    // a); a[1] and a[3] alias a23_lo / a23_hi, which is fine because each
    // instruction reads its sources before writing its destination.
    // NOTE(review): a01_hi / a23_hi are written above when remainder >= 8 but
    // consumed here only when remainder >= 16 — confirm this is intended.
    e->vpunpcklwd(a[0], a01_lo, a23_lo);
    e->vpunpckhwd(a[1], a01_lo, a23_lo);
    if (remainder >= 16) {
      e->vpunpcklwd(a[2], a01_hi, a23_hi);
      e->vpunpckhwd(a[3], a01_hi, a23_hi);
    }

    // Multiply the unsigned input bytes with the signed bytes read from
    // memory at b and add adjacent products into 16-bit elements.
    e->vpmaddubsw(a[0], a[0], x86::ymmword_ptr(b));
    e->vpmaddubsw(a[1], a[1], x86::ymmword_ptr(b, 32));
    if (remainder >= 16) {
      e->vpmaddubsw(a[2], a[2], x86::ymmword_ptr(b, 64));
      e->vpmaddubsw(a[3], a[3], x86::ymmword_ptr(b, 96));
    }

    // vpmaddwd with one_epi16 adds adjacent 16-bit elements into 32-bit
    // elements, completing the sum over the four inputs.
    if (accumulation) {
      // Widen in place in a[], then add into the existing accumulators.
      e->vpmaddwd(a[0], a[0], one_epi16);
      e->vpaddd(c[0], c[0], a[0]);
      e->vpmaddwd(a[1], a[1], one_epi16);
      e->vpaddd(c[1], c[1], a[1]);

      if (remainder >= 16) {
        e->vpmaddwd(a[2], a[2], one_epi16);
        e->vpaddd(c[2], c[2], a[2]);
        e->vpmaddwd(a[3], a[3], one_epi16);
        e->vpaddd(c[3], c[3], a[3]);
      }
    } else {
      // First contribution: write the widened sums straight into c.
      e->vpmaddwd(c[0], a[0], one_epi16);
      e->vpmaddwd(c[1], a[1], one_epi16);

      if (remainder >= 16) {
        e->vpmaddwd(c[2], a[2], one_epi16);
        e->vpmaddwd(c[3], a[3], one_epi16);
      }
    }
  } else {
    // One or two inputs: a single vpmaddubsw already sums both products of
    // each pair into a 16-bit element, so only a sign-extension to 32 bits is
    // left. Results go into a[0] / a[1] (reusing a).
    // NOTE(review): a01_hi is only written above when remainder >= 8 but is
    // read here unconditionally; presumably the caller ignores c[1] in that
    // case — confirm.
    e->vpmaddubsw(a[0], a01_lo, x86::ymmword_ptr(b));
    e->vpmaddubsw(a[1], a01_hi, x86::ymmword_ptr(b, 32));

    if (accumulation) {
      // Sign-extend the low 128-bit halves to 32-bit elements (a[2] / a[3]
      // serve as temporaries) and add them into c[0] / c[1].
      e->vpmovsxwd(a[2], a[0].half());
      e->vpaddd(c[0], c[0], a[2]);
      e->vpmovsxwd(a[3], a[1].half());
      e->vpaddd(c[1], c[1], a[3]);

      if (remainder >= 16) {
        // Move the upper 128 bits into the low half, then widen in place and
        // add into c[2] / c[3]. This destroys a[0] / a[1], so it must come
        // after the low halves have been consumed above.
        e->vextracti128(a[0].half(), a[0], asmjit::Imm(1));
        e->vpmovsxwd(a[0], a[0].half());
        e->vpaddd(c[2], c[2], a[0]);
        e->vextracti128(a[1].half(), a[1], asmjit::Imm(1));
        e->vpmovsxwd(a[1], a[1].half());
        e->vpaddd(c[3], c[3], a[1]);
      }
    } else {
      // First contribution: sign-extend the low halves directly into c.
      e->vpmovsxwd(c[0], a[0].half());
      e->vpmovsxwd(c[1], a[1].half());

      if (remainder >= 16) {
        // Upper 128 bits: extract into the low half, then sign-extend
        // directly into c[2] / c[3].
        e->vextracti128(a[0].half(), a[0], asmjit::Imm(1));
        e->vpmovsxwd(c[2], a[0].half());
        e->vextracti128(a[1].half(), a[1], asmjit::Imm(1));
        e->vpmovsxwd(c[3], a[1].half());
      }
    }
  }
}