in src/RefImplementations.cc [1634:1774]
template <typename DataType, typename IndexType, typename OffsetType>
bool rowwise_sparse_adagrad_fused_ref(
int64_t block_size,
int64_t output_size,
int64_t index_size,
int64_t data_size,
DataType* w,
const float* g,
float* h,
const IndexType* indices,
const OffsetType* offsets_or_lengths,
float epsilon,
float lr,
bool use_offsets,
bool use_stochastic_rounding,
int emu_vector_size,
int64_t grad_stride) {
if (grad_stride == -1) {
grad_stride = block_size;
}
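  // A negative grad_stride means the gradient rows are densely packed, i.e.
  // the row stride equals block_size.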
constexpr bool isFloat16w = std::is_same<float16, DataType>::value;
  // Local random buffers emulating one SIMD vector:
  // R: 32-bit base random numbers, one per lane
  // r: 8 bits extracted from R, pre-shifted into the rounding position
constexpr int VLEN_MAX = 16;
uint32_t R[VLEN_MAX], r[VLEN_MAX];
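  // vlen is the number of fp32 lanes being emulated (8 for AVX2, 16 for
  // AVX512); VLEN_MAX bounds the R/r buffers at the widest case.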
int vlen = emu_vector_size;
if (vlen != 8 && vlen != 16) {
    // Reject unsupported vector lengths; anything else would overflow R/r
    cerr << "Unsupported emu_vector_size: " << emu_vector_size << endl;
    return false;
}
int64_t current = 0;
for (int m = 0; m < output_size; ++m) {
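    // Segment m selects len consecutive entries of indices: with offsets the
    // range is [offsets[m], offsets[m + 1]), with lengths it is lengths[m].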
int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
: offsets_or_lengths[m];
if (current + len > index_size) {
return false;
}
const float* g_ = g + m * grad_stride;
    // Note the following code assumes fbgemm generates AVX2 code for the
    // horizontal reduction. This is OK for now because fbgemm always uses
    // AVX2 for SparseAdagrad: its performance is bound by memory bandwidth,
    // so AVX512 brings no speedup.
// Non-vectorized version would be just
// for (auto j = 0; j < block_size; ++j) {
// float gj = g_[j];
// final_sum += gj * gj;
// }
constexpr int VLEN_AVX2 = 8;
array<float, VLEN_AVX2> partial_sum = {0.0f};
for (auto j = 0; j < block_size; ++j) {
float gj = g_[j];
partial_sum[j % VLEN_AVX2] += gj * gj;
}
float final_sum = ((partial_sum[0] + partial_sum[1]) +
(partial_sum[2] + partial_sum[3])) +
((partial_sum[4] + partial_sum[5]) + (partial_sum[6] + partial_sum[7]));
final_sum /= block_size;
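    // final_sum is the mean squared gradient of this output row, i.e. the
    // rowwise AdaGrad statistic added to every selected row's momentum below.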
for (int i = 0; i < len; ++i, ++current) {
int64_t idx = indices[current];
if (idx < 0 || idx >= data_size) {
return false;
}
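      // Rowwise AdaGrad keeps a single momentum value h[idx] per embedding
      // row, shared by all block_size elements, so the whole row uses one
      // effective step size lr / (sqrt(h[idx]) + epsilon).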
float* h_ = h + idx;
DataType* w_ = w + idx * block_size;
float hi = *h_ = *h_ + final_sum;
float float_step = lr / (std::sqrt(hi) + epsilon);
int nvec = (block_size + vlen - 1) / vlen;
int rem = (block_size % vlen) ? (block_size % vlen) : vlen;
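      // Walk the row in emulated-vector chunks of vlen elements; the last
      // chunk may be partial and holds rem elements.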
// Emulate JIT behavior of stochastic rounding with vector-length
//
      // Generate the R buffer every 4 iterations of the nvec loop. Each byte
      // of R (uint32_t) is used once: it is shifted into bits [5..12] and then
      // added to the FP32 weight's bit pattern before the FP16 conversion.
//
// The shifted 8 bit region
// +-------+--------+--------+--------+
// | | | xxxxx|xxx |
// 31 23 15 7 0
//
      // A half float has 10 mantissa bits and a float has 23, so the low 13
      // mantissa bits (bits 0..12) of the fp32 value cannot be represented in
      // fp16 and are dropped by the conversion. Injecting random bits into
      // that region before truncating effectively adds a uniform random
      // offset in [0, 1) of the fp16 ulp.
for (int n = 0; n < nvec; ++n) {
int cur_vlen = (n == nvec - 1) ? rem : vlen;
int sr_idx = n % 4;
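        // Byte sr_idx of each lane's 32-bit random number provides the 8
        // random rounding bits for this vector.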
if (isFloat16w && use_stochastic_rounding) {
if (sr_idx == 0) {
for (int v = 0; v < vlen; ++v) {
R[v] = rnd128_next(v, vlen);
r[v] = (R[v] & 0xFFU) << 5;
}
} else if (sr_idx == 1) {
for (int v = 0; v < vlen; ++v) {
r[v] = ((R[v] & 0xFF00U) >> 8) << 5;
}
} else if (sr_idx == 2) {
for (int v = 0; v < vlen; ++v) {
r[v] = ((R[v] & 0xFF0000U) >> 16) << 5;
}
} else { // 3
for (int v = 0; v < vlen; ++v) {
r[v] = ((R[v] & 0xFF000000U) >> 24) << 5;
}
}
}
for (int v = 0; v < cur_vlen; ++v) {
int j = n * vlen + v;
if (isFloat16w) {
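            // Reinterpret the fp32 weight's bits as an integer so the random
            // rounding bits can be added directly to its bit pattern.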
union {
float w_f32;
uint32_t w_i32;
};
w_f32 = cpu_half2float(w_[j]);
w_f32 = std::fma(float_step, g_[j], w_f32);
if (use_stochastic_rounding) {
w_i32 += r[v];
}
            // Truncate (round toward zero) so that, together with the random
            // bits added above, the combined effect is stochastic rounding
            // rather than a systematic upward bias.
w_[j] = cpu_float2half_rz(w_f32);
} else { // float
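            // fp32 weights need no rounding emulation; update in place.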
w_[j] += g_[j] * float_step;
}
}
}
}
}
return current == index_size;
}
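
// Usage sketch (not part of the reference implementation): illustrates how a
// test might call this function. Buffer names are hypothetical and assume the
// fp16 instantiation with IndexType = int64_t and OffsetType = int32_t.
//
//   std::vector<float16> w(data_size * block_size); // fp16 weights
//   std::vector<float> h(data_size);                // one momentum per row
//   std::vector<float> g(output_size * block_size); // gradients
//   std::vector<int64_t> indices(index_size);
//   std::vector<int32_t> offsets(output_size + 1);  // offsets[output_size] == index_size
//   bool ok = rowwise_sparse_adagrad_fused_ref(
//       block_size, output_size, index_size, data_size,
//       w.data(), g.data(), h.data(), indices.data(), offsets.data(),
//       /*epsilon=*/1e-5f, /*lr=*/0.01f, /*use_offsets=*/true,
//       /*use_stochastic_rounding=*/true, /*emu_vector_size=*/8,
//       /*grad_stride=*/-1);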