void GenSparseAdagrad::genRowwiseSparseAdagrad()

in src/SparseAdagrad.cc [230:438]


void GenSparseAdagrad<indxType, instSet>::genRowwiseSparseAdagrad(
    x86::Emitter* a,
    int block_size,
    int unroll_factor,
    int num_vec_regs_per_block,
    int remainder,
    int prefetch,
    typename simd_info<instSet>::vec_reg_t epsilon_vreg,
    typename simd_info<instSet>::vec_reg_t lr_vreg,
    x86::Ymm mask_vreg,
    typename simd_info<instSet>::vec_reg_t temp_vreg,
    typename simd_info<instSet>::vec_reg_t weight_decay_vreg,
    bool has_weight_decay) {
  typedef typename simd_info<instSet>::vec_reg_t vec_reg_t;
  constexpr int vlen = simd_info<instSet>::WIDTH_32BIT_ELEMS;

  // Reduce the unroll factor by 1 for partial sum
  --unroll_factor;
  vec_reg_t partial_sum_vreg = vec_reg_t(unroll_factor);

  if (prefetch) {
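    // temp3_ presumably holds the byte offset into h of the row being
    // prefetched (the row `prefetch` iterations ahead)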
    a->prefetchw(x86::dword_ptr(h, temp3_));
  }

  bool areIndices64b = std::is_same<indxType, std::int64_t>::value;
  auto indices_ptr = areIndices64b
      ? x86::qword_ptr(
            indices, temp1_, 3) // shift of 3 scales the index by 8 (sizeof(int64_t))
      : x86::dword_ptr(
            indices, temp1_, 2); // shift of 2 scales the index by 4 (sizeof(int32_t))
  if (has_weight_decay) {
    // set base_offset for fetching w in the calculation of gradient square sum
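    // (base_offset = indices[temp1_] * block_size * sizeof(float), the byte
    // offset of this row's first weight element)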
    a->imul(
        areIndices64b ? base_offset : base_offset.r32(),
        indices_ptr,
        static_cast<asmjit::Imm>(block_size * sizeof(float)));
  }

  // Even with avx512, we only need avx2 (256-bit) registers when computing
  // partial_sum because some of the instructions we use, like vhaddps,
  // only have avx2 forms.
  constexpr int vlen_avx2 = simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
  int num_vec_regs_per_block_avx2 = (block_size + vlen_avx2 - 1) / vlen_avx2;

  // Use YMM/XMMs with smaller ids for AVX2 specific instructions like vhaddps
  x86::Ymm partial_sum_vreg_avx2(0);
  x86::Xmm partial_sum_xmm0(partial_sum_vreg_avx2.id());

  a->vxorps(
      partial_sum_vreg_avx2, partial_sum_vreg_avx2, partial_sum_vreg_avx2);

  // TODO: need to do a tree-reduction to fully take advantage of unrolling
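  // The loop below accumulates, lane-wise in a 256-bit register,
  //   partial_sum += (g[j] (+ weight_decay * w[idx * N + j]))^2
  // over the elements of the row; the lanes are reduced to a scalar below.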
  for (int vec_idx = 0; vec_idx < num_vec_regs_per_block_avx2;
       vec_idx += unroll_factor - 1) {
    int cur_unroll_factor =
        std::min(unroll_factor - 1, num_vec_regs_per_block_avx2 - vec_idx);
    for (int v = 0; v < cur_unroll_factor; ++v) {
      x86::Ymm out_vreg = x86::Ymm(v + 1);
      if (has_weight_decay && prefetch &&
          ((vec_idx + v) % (64 / (vlen_avx2 * sizeof(float))) == 0)) {
        a->prefetchw(x86::dword_ptr(
            w, temp2_, 0, (vec_idx + v) * vlen_avx2 * sizeof(float)));
      }

      auto g_ptr = x86::dword_ptr(g, (vec_idx + v) * vlen_avx2 * sizeof(float));
      auto w_ptr = x86::dword_ptr(
          w, base_offset, 0, (vec_idx + v) * vlen_avx2 * sizeof(float));
      if (block_size % simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS &&
          vec_idx + v == num_vec_regs_per_block_avx2 - 1) {
        if (instSet == inst_set_t::avx2) {
          a->vmaskmovps(out_vreg, mask_vreg, g_ptr);
          if (has_weight_decay) {
            a->vmaskmovps(temp_vreg.ymm(), mask_vreg, w_ptr);
            a->vfmadd231ps(out_vreg, temp_vreg, weight_decay_vreg);
          }
        } else {
          a->k(reduce_mask_avx512_).z().vmovups(out_vreg, g_ptr);
          if (has_weight_decay) {
            a->k(reduce_mask_avx512_)
                .z()
                .vfmadd231ps(out_vreg, weight_decay_vreg, w_ptr);
          }
        }
      } else {
        a->vmovups(out_vreg, g_ptr);
        if (has_weight_decay) {
          a->vfmadd231ps(out_vreg, weight_decay_vreg, w_ptr);
        }
      }
      a->vmulps(out_vreg, out_vreg, out_vreg);
      a->vaddps(partial_sum_vreg_avx2, partial_sum_vreg_avx2, out_vreg);
    }
  }
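  // partial_sum_vreg_avx2 now holds eight lane-wise partial sums of squares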
  // Reduce sum to 1 value
  // __m256 partial_sum_2 = _mm256_hadd_ps(partial_sum, partial_sum);
  // __m256 partial_sum_3 = _mm256_hadd_ps(partial_sum_2, partial_sum_2);
  a->vhaddps(
      partial_sum_vreg_avx2, partial_sum_vreg_avx2, partial_sum_vreg_avx2);
  a->vhaddps(
      partial_sum_vreg_avx2, partial_sum_vreg_avx2, partial_sum_vreg_avx2);

  x86::Xmm partial_sum_xmm1(1);

  //_mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3))
  a->movss(partial_sum_xmm1, partial_sum_xmm0);
  //_mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1))
  a->vextractf128(partial_sum_xmm0, partial_sum_vreg_avx2, 1);

  // final_sum = _mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3)) +
  //    _mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1));
  a->addss(partial_sum_xmm0, partial_sum_xmm1);
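  // the low element of partial_sum_xmm0 now holds the row's full sum of squares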

  // This fragment moves the block size (N) to the stack and broadcasts it to an xmm reg
  a->lea(
      x86::rsp,
      x86::dword_ptr(x86::rsp, -1 * static_cast<int>(sizeof(int32_t))));
  a->mov(x86::dword_ptr(x86::rsp), block_size);
  a->vbroadcastss(
      partial_sum_xmm1, x86::dword_ptr(x86::rsp)); // N is partial_sum_xmm1
  a->vcvtdq2ps(partial_sum_xmm1, partial_sum_xmm1);
  a->lea(x86::rsp, x86::dword_ptr(x86::rsp, sizeof(int32_t)));
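  // partial_sum_xmm1 now holds (float)block_size, the divisor for averaging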

  if (has_weight_decay) {
    // set base_offset for fetching h
    a->imul(
        areIndices64b ? base_offset : base_offset.r32(),
        indices_ptr,
        static_cast<asmjit::Imm>(sizeof(float)));
  }

  // final_sum /= N
  a->divss(partial_sum_xmm0, partial_sum_xmm1);
  // load h
  a->movss(partial_sum_xmm1, x86::dword_ptr(h, base_offset));
  // *h + final_sum
  a->addss(partial_sum_xmm0, partial_sum_xmm1);
  // store h
  a->movss(x86::dword_ptr(h, base_offset), partial_sum_xmm0);
  // sqrt(hi)
  a->sqrtss(partial_sum_xmm0, partial_sum_xmm0);
  // broadcast the scalar to all lanes of the ymm/zmm reg
  a->vpbroadcastd(partial_sum_vreg, partial_sum_xmm0);
  // float_step = lr / (sqrt(hi) + epsilon)
  a->vaddps(partial_sum_vreg, partial_sum_vreg, epsilon_vreg);
  a->vdivps(partial_sum_vreg, lr_vreg, partial_sum_vreg);
  // partial_sum_vreg now has float_step
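  // In scalar terms (idx = indices[i], N = block_size):
  //   h[idx] += (1.0f / N) * sum_j (g[j] (+ weight_decay * w[idx * N + j]))^2
  //   float_step = lr / (sqrt(h[idx]) + epsilon)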

  // set base_offset for fetching w in updating weights
  a->imul(
      areIndices64b ? base_offset : base_offset.r32(),
      indices_ptr,
      static_cast<asmjit::Imm>(block_size * sizeof(float)));
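  // The loop below then applies, per element j of the row:
  //   w[idx * N + j] += float_step * (g[j] (+ weight_decay * w[idx * N + j]))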

  for (int vec_idx = 0; vec_idx < num_vec_regs_per_block;
       vec_idx += unroll_factor) {
    int cur_unroll_factor =
        std::min(unroll_factor, num_vec_regs_per_block - vec_idx);

    for (int v = 0; v < cur_unroll_factor; ++v) {
      vec_reg_t out_vreg = vec_reg_t(v);

      if (!has_weight_decay && prefetch &&
          ((vec_idx + v) % (64 / (vlen * sizeof(float))) == 0)) {
        a->prefetchw(
            x86::dword_ptr(w, temp2_, 0, (vec_idx + v) * vlen * sizeof(float)));
      }

      auto g_ptr = x86::dword_ptr(g, (vec_idx + v) * vlen * sizeof(float));
      auto w_ptr = x86::dword_ptr(
          w, base_offset, 0, (vec_idx + v) * vlen * sizeof(float));
      if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
        if (instSet == inst_set_t::avx2) {
          a->vmaskmovps(temp_vreg.ymm(), mask_vreg, g_ptr);
          if (has_weight_decay) {
            a->vmaskmovps(out_vreg.ymm(), mask_vreg, w_ptr);
            // TODO(@taiqing): have vreg for weights
            a->vfmadd231ps(temp_vreg, weight_decay_vreg, out_vreg);
          }
          a->vmulps(temp_vreg, partial_sum_vreg, temp_vreg);

          a->vmaskmovps(out_vreg.ymm(), mask_vreg, w_ptr);
          a->vaddps(out_vreg, temp_vreg, out_vreg);

          a->vmaskmovps(w_ptr, mask_vreg, out_vreg.ymm());
        } else {
          if (has_weight_decay) {
            a->k(x86::k(1)).vmovups(out_vreg, g_ptr);
            a->k(x86::k(1)).vfmadd231ps(out_vreg, weight_decay_vreg, w_ptr);
            a->k(x86::k(1)).vmulps(out_vreg, partial_sum_vreg, out_vreg);
          } else {
            a->k(x86::k(1)).vmulps(out_vreg, partial_sum_vreg, g_ptr);
          }
          a->k(x86::k(1)).vaddps(out_vreg, out_vreg, w_ptr);
          a->k(x86::k(1)).vmovups(w_ptr, out_vreg);
        }
      } else {
        if (has_weight_decay) {
          a->vmovups(out_vreg, g_ptr);
          a->vfmadd231ps(out_vreg, weight_decay_vreg, w_ptr);
          a->vmulps(out_vreg, partial_sum_vreg, out_vreg);
        } else {
          a->vmulps(out_vreg, partial_sum_vreg, g_ptr);
        }
        a->vaddps(out_vreg, out_vreg, w_ptr);
        a->vmovups(w_ptr, out_vreg);
      }
    }
  }
}
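
For reference, the per-row computation the generated kernel performs can be
written as the scalar sketch below (illustrative only: the function name and
signature are not part of FBGEMM, and idx stands for indices[i]).

#include <cmath>
#include <cstdint>

// Scalar reference for one row of the rowwise sparse Adagrad update above
// (a sketch, not FBGEMM API).
static void rowwise_adagrad_ref_row(
    float* w, // weights, row-major, block_size floats per row
    float* h, // per-row momentum, one float per row
    const float* g, // gradient for this sample, block_size floats
    std::int64_t idx, // idx = indices[i]
    int block_size,
    float epsilon,
    float lr,
    float weight_decay,
    bool has_weight_decay) {
  float* w_row = w + idx * static_cast<std::int64_t>(block_size);
  // gradient square sum, with weight decay folded in when enabled
  float partial_sum = 0.0f;
  for (int j = 0; j < block_size; ++j) {
    float gj = g[j];
    if (has_weight_decay) {
      gj += weight_decay * w_row[j];
    }
    partial_sum += gj * gj;
  }
  // update the momentum with the row-wise average and derive the step size
  h[idx] += partial_sum / block_size;
  float float_step = lr / (std::sqrt(h[idx]) + epsilon);
  // apply the step to every element of the row
  for (int j = 0; j < block_size; ++j) {
    float gj = g[j];
    if (has_weight_decay) {
      gj += weight_decay * w_row[j];
    }
    w_row[j] += float_step * gj;
  }
}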