in src/SparseAdagrad.cc [230:438]
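// Emits the JIT code for one row of rowwise sparse Adagrad: accumulate the
// row's average squared gradient (optionally with weight decay folded into
// the gradient), add it to the single h entry for that row, form the step
// size lr / (sqrt(h) + epsilon), and apply the scaled gradient to the row of w.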
void GenSparseAdagrad<indxType, instSet>::genRowwiseSparseAdagrad(
x86::Emitter* a,
int block_size,
int unroll_factor,
int num_vec_regs_per_block,
int remainder,
int prefetch,
typename simd_info<instSet>::vec_reg_t epsilon_vreg,
typename simd_info<instSet>::vec_reg_t lr_vreg,
x86::Ymm mask_vreg,
typename simd_info<instSet>::vec_reg_t temp_vreg,
typename simd_info<instSet>::vec_reg_t weight_decay_vreg,
bool has_weight_decay) {
typedef typename simd_info<instSet>::vec_reg_t vec_reg_t;
constexpr int vlen = simd_info<instSet>::WIDTH_32BIT_ELEMS;
// Reduce the unroll factor by 1 to reserve a register for the partial sum
--unroll_factor;
vec_reg_t partial_sum_vreg = vec_reg_t(unroll_factor);
if (prefetch) {
a->prefetchw(x86::dword_ptr(h, temp3_));
}
bool areIndices64b = std::is_same<indxType, std::int64_t>::value;
auto indices_ptr = areIndices64b
? x86::qword_ptr(
indices, temp1_, 3) // use of 3 is to multiply by 8 (int64_t)
: x86::dword_ptr(
indices, temp1_, 2); // use of 2 is to multiply by 4 (int32_t)
if (has_weight_decay) {
// set base_offset for fetching w in the calculation of the squared-gradient sum
a->imul(
areIndices64b ? base_offset : base_offset.r32(),
indices_ptr,
static_cast<asmjit::Imm>(block_size * sizeof(float)));
}
// Even with avx512, we only need to use avx2 registers when computing
// partial_sum because some of the instructions we rely on, like vhaddps,
// have no avx512 counterpart.
constexpr int vlen_avx2 = simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
int num_vec_regs_per_block_avx2 = (block_size + vlen_avx2 - 1) / vlen_avx2;
// Use YMM/XMMs with smaller ids for AVX2-specific instructions like vhaddps
x86::Ymm partial_sum_vreg_avx2(0);
x86::Xmm partial_sum_xmm0(partial_sum_vreg_avx2.id());
a->vxorps(
partial_sum_vreg_avx2, partial_sum_vreg_avx2, partial_sum_vreg_avx2);
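// Accumulate the sum of squared gradients for the row. When weight decay is
// enabled, the decay term weight_decay * w is folded into the gradient before
// squaring.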
// TODO: need to do a tree-reduction to fully take advantage of unrolling
for (int vec_idx = 0; vec_idx < num_vec_regs_per_block_avx2;
vec_idx += unroll_factor - 1) {
int cur_unroll_factor =
std::min(unroll_factor - 1, num_vec_regs_per_block_avx2 - vec_idx);
for (int v = 0; v < cur_unroll_factor; ++v) {
x86::Ymm out_vreg = x86::Ymm(v + 1);
if (has_weight_decay && prefetch &&
((vec_idx + v) % (64 / (vlen_avx2 * sizeof(float))) == 0)) {
a->prefetchw(x86::dword_ptr(
w, temp2_, 0, (vec_idx + v) * vlen_avx2 * sizeof(float)));
}
auto g_ptr = x86::dword_ptr(g, (vec_idx + v) * vlen_avx2 * sizeof(float));
auto w_ptr = x86::dword_ptr(
w, base_offset, 0, (vec_idx + v) * vlen_avx2 * sizeof(float));
if (block_size % vlen_avx2 &&
vec_idx + v == num_vec_regs_per_block_avx2 - 1) {
if (instSet == inst_set_t::avx2) {
a->vmaskmovps(out_vreg, mask_vreg, g_ptr);
if (has_weight_decay) {
a->vmaskmovps(temp_vreg.ymm(), mask_vreg, w_ptr);
a->vfmadd231ps(out_vreg, temp_vreg, weight_decay_vreg);
}
} else {
a->k(reduce_mask_avx512_).z().vmovups(out_vreg, g_ptr);
if (has_weight_decay) {
a->k(reduce_mask_avx512_)
.z()
.vfmadd231ps(out_vreg, weight_decay_vreg, w_ptr);
}
}
} else {
a->vmovups(out_vreg, g_ptr);
if (has_weight_decay) {
a->vfmadd231ps(out_vreg, weight_decay_vreg, w_ptr);
}
}
a->vmulps(out_vreg, out_vreg, out_vreg);
a->vaddps(partial_sum_vreg_avx2, partial_sum_vreg_avx2, out_vreg);
}
}
// Reduce sum to 1 value
// __m256 partial_sum_2 = _mm256_hadd_ps(partial_sum, partial_sum);
// __m256 partial_sum_3 = _mm256_hadd_ps(partial_sum_2, partial_sum_2);
a->vhaddps(
partial_sum_vreg_avx2, partial_sum_vreg_avx2, partial_sum_vreg_avx2);
a->vhaddps(
partial_sum_vreg_avx2, partial_sum_vreg_avx2, partial_sum_vreg_avx2);
x86::Xmm partial_sum_xmm1(1);
//_mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3))
a->movss(partial_sum_xmm1, partial_sum_xmm0);
//_mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1))
a->vextractf128(partial_sum_xmm0, partial_sum_vreg_avx2, 1);
// final_sum = _mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3)) +
// _mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1));
a->addss(partial_sum_xmm0, partial_sum_xmm1);
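// partial_sum_xmm0 now holds the scalar sum of squared gradients for the row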
// Push block_size (N) onto the stack, broadcast it into an xmm register,
// and convert it to float
a->lea(
x86::rsp,
x86::dword_ptr(x86::rsp, -1 * static_cast<int>(sizeof(int32_t))));
a->mov(x86::dword_ptr(x86::rsp), block_size);
a->vbroadcastss(
partial_sum_xmm1, x86::dword_ptr(x86::rsp)); // N is partial_sum_xmm1
a->vcvtdq2ps(partial_sum_xmm1, partial_sum_xmm1);
a->lea(x86::rsp, x86::dword_ptr(x86::rsp, sizeof(int32_t)));
if (has_weight_decay) {
// set base_offset for fetching h
a->imul(
areIndices64b ? base_offset : base_offset.r32(),
indices_ptr,
static_cast<asmjit::Imm>(sizeof(float)));
}
// final_sum /= N
a->divss(partial_sum_xmm0, partial_sum_xmm1);
// load h
a->movss(partial_sum_xmm1, x86::dword_ptr(h, base_offset));
// *h + final_sum
a->addss(partial_sum_xmm0, partial_sum_xmm1);
// store h
a->movss(x86::dword_ptr(h, base_offset), partial_sum_xmm0);
// sqrt(hi)
a->sqrtss(partial_sum_xmm0, partial_sum_xmm0);
// broadcast the scalar sqrt(hi) to all lanes of the ymm/zmm reg
a->vpbroadcastd(partial_sum_vreg, partial_sum_xmm0);
// float_step = lr / (sqrt(hi) + epsilon)
a->vaddps(partial_sum_vreg, partial_sum_vreg, epsilon_vreg);
a->vdivps(partial_sum_vreg, lr_vreg, partial_sum_vreg);
// partial_sum_vreg now has float_step
// set base_offset for fetching w when updating the weights
a->imul(
areIndices64b ? base_offset : base_offset.r32(),
indices_ptr,
static_cast<asmjit::Imm>(block_size * sizeof(float)));
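// Update loop: w_i += float_step * (g_i + weight_decay * w_i) for every
// element of the row (the weight_decay term only when has_weight_decay)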
for (int vec_idx = 0; vec_idx < num_vec_regs_per_block;
vec_idx += unroll_factor) {
int cur_unroll_factor =
std::min(unroll_factor, num_vec_regs_per_block - vec_idx);
for (int v = 0; v < cur_unroll_factor; ++v) {
vec_reg_t out_vreg = vec_reg_t(v);
if (!has_weight_decay && prefetch &&
((vec_idx + v) % (64 / (vlen * sizeof(float))) == 0)) {
a->prefetchw(
x86::dword_ptr(w, temp2_, 0, (vec_idx + v) * vlen * sizeof(float)));
}
auto g_ptr = x86::dword_ptr(g, (vec_idx + v) * vlen * sizeof(float));
auto w_ptr = x86::dword_ptr(
w, base_offset, 0, (vec_idx + v) * vlen * sizeof(float));
if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
if (instSet == inst_set_t::avx2) {
a->vmaskmovps(temp_vreg.ymm(), mask_vreg, g_ptr);
if (has_weight_decay) {
a->vmaskmovps(out_vreg.ymm(), mask_vreg, w_ptr);
// TODO(@taiqing): have vreg for weights
a->vfmadd231ps(temp_vreg, weight_decay_vreg, out_vreg);
}
a->vmulps(temp_vreg, partial_sum_vreg, temp_vreg);
a->vmaskmovps(out_vreg.ymm(), mask_vreg, w_ptr);
a->vaddps(out_vreg, temp_vreg, out_vreg);
a->vmaskmovps(w_ptr, mask_vreg, out_vreg.ymm());
} else {
if (has_weight_decay) {
a->k(x86::k(1)).vmovups(out_vreg, g_ptr);
a->k(x86::k(1)).vfmadd231ps(out_vreg, weight_decay_vreg, w_ptr);
a->k(x86::k(1)).vmulps(out_vreg, partial_sum_vreg, out_vreg);
} else {
a->k(x86::k(1)).vmulps(out_vreg, partial_sum_vreg, g_ptr);
}
a->k(x86::k(1)).vaddps(out_vreg, out_vreg, w_ptr);
a->k(x86::k(1)).vmovups(w_ptr, out_vreg);
}
} else {
if (has_weight_decay) {
a->vmovups(out_vreg, g_ptr);
a->vfmadd231ps(out_vreg, weight_decay_vreg, w_ptr);
a->vmulps(out_vreg, partial_sum_vreg, out_vreg);
} else {
a->vmulps(out_vreg, partial_sum_vreg, g_ptr);
}
a->vaddps(out_vreg, out_vreg, w_ptr);
a->vmovups(w_ptr, out_vreg);
}
}
}
}
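// For reference, a scalar sketch of the update the generated code performs
// for one row (illustrative only; `idx`, `g`, `h`, and `w` are hypothetical
// names, not identifiers from this file):
//
//   float final_sum = 0.0f;
//   for (int j = 0; j < block_size; ++j) {
//     float gj = g[j];
//     if (has_weight_decay) {
//       gj += weight_decay * w[idx * block_size + j];
//     }
//     final_sum += gj * gj;
//   }
//   final_sum /= block_size;
//   h[idx] += final_sum;
//   float float_step = lr / (std::sqrt(h[idx]) + epsilon);
//   for (int j = 0; j < block_size; ++j) {
//     float gj = g[j];
//     if (has_weight_decay) {
//       gj += weight_decay * w[idx * block_size + j];
//     }
//     w[idx * block_size + j] += float_step * gj;
//   }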