int rowwise_sparse_adagrad_fused_ref()

in src/RefImplementations.cc [1634:1774]


template <typename DataType, typename IndexType, typename OffsetType>
int rowwise_sparse_adagrad_fused_ref(
    int64_t block_size,
    int64_t output_size,
    int64_t index_size,
    int64_t data_size,
    DataType* w,
    const float* g,
    float* h,
    const IndexType* indices,
    const OffsetType* offsets_or_lengths,
    float epsilon,
    float lr,
    bool use_offsets,
    bool use_stochastic_rounding,
    int emu_vector_size,
    int64_t grad_stride) {
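  // A grad_stride of -1 is a sentinel meaning the gradient rows are packed
  // contiguously, i.e. the stride defaults to block_size.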
  if (grad_stride == -1) {
    grad_stride = block_size;
  }

  constexpr bool isFloat16w = std::is_same<float16, DataType>::value;
  // Local random buffers emulating a SIMD vector:
  // R: generated 32-bit base random numbers
  // r: 8 bits extracted from R, pre-shifted into rounding position
  constexpr int VLEN_MAX = 16;
  uint32_t R[VLEN_MAX], r[VLEN_MAX];
  int vlen = emu_vector_size;
  if (vlen != 8 && vlen != 16) {
    // Reject unsupported vector lengths; anything else could overflow R/r.
    cerr << "Unsupported emu_vector_size: " << emu_vector_size << endl;
    return false;
  }

  int64_t current = 0;
  for (int m = 0; m < output_size; ++m) {
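    // Bag length comes either from consecutive offsets (CSR-style) or from
    // an explicit per-bag lengths array.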
    int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
                          : offsets_or_lengths[m];
    if (current + len > index_size) {
      return false;
    }
    const float* g_ = g + m * grad_stride;
    // Note: the following code assumes fbgemm generates AVX2 code for the
    // horizontal reduction. This is OK for now because fbgemm always uses
    // AVX2 for SparseAdagrad: its performance is bounded by memory
    // bandwidth, so AVX512 brings no speedup.
    // Non-vectorized version would be just
    // for (auto j = 0; j < block_size; ++j) {
    //   float gj = g_[j];
    //   final_sum += gj * gj;
    // }
    constexpr int VLEN_AVX2 = 8;
    array<float, VLEN_AVX2> partial_sum = {0.0f};
    for (auto j = 0; j < block_size; ++j) {
      float gj = g_[j];
      partial_sum[j % VLEN_AVX2] += gj * gj;
    }
    float final_sum = ((partial_sum[0] + partial_sum[1]) +
                       (partial_sum[2] + partial_sum[3])) +
        ((partial_sum[4] + partial_sum[5]) + (partial_sum[6] + partial_sum[7]));
    final_sum /= block_size;
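    // final_sum is the mean of squared gradients for this bag's output row;
    // row-wise AdaGrad adds this one scalar to the momentum of every
    // embedding row gathered below.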

    for (int i = 0; i < len; ++i, ++current) {
      int64_t idx = indices[current];
      if (idx < 0 || idx >= data_size) {
        return false;
      }

      float* h_ = h + idx;
      DataType* w_ = w + idx * block_size;

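      // Row-wise AdaGrad: one momentum scalar per embedding row (h has one
      // entry per row), so a single step size applies to the whole row.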
      float hi = *h_ = *h_ + final_sum;
      float float_step = lr / (std::sqrt(hi) + epsilon);

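      // Split the row into nvec emulated vectors of vlen lanes; the last
      // vector carries rem active lanes.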
      int nvec = (block_size + vlen - 1) / vlen;
      int rem = (block_size % vlen) ? (block_size % vlen) : vlen;

      // Emulate JIT behavior of stochastic rounding with vector-length
      //
      // Regenerate the R buffer every 4 iterations of the nvec loop. Each
      // 8-bit chunk of R (uint32_t) is used once: it is shifted into bits
      // 5..12 and added to the FP32 weight before the FP16 conversion.
      //
      // The shifted 8 bit region
      // +-------+--------+--------+--------+
      // |       |        |   xxxxx|xxx     |
      //  31      23       15       7      0
      //
      // Half float has 10 bits of mantissa while float has 23; the half
      // mantissa maps to bits 13..22 of the fp32 mantissa, so the random
      // bits land in the region (bits 0..12) that half floats cannot
      // represent. This effectively adds a random variable in [0,1) ULP
      // (of fp16) before the truncating conversion.

      for (int n = 0; n < nvec; ++n) {
        int cur_vlen = (n == nvec - 1) ? rem : vlen;
        int sr_idx = n % 4;

        if (isFloat16w && use_stochastic_rounding) {
          if (sr_idx == 0) {
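            // Draw fresh 32-bit randoms for each lane (rnd128_next is the
            // per-lane PRNG helper used by this reference).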
            for (int v = 0; v < vlen; ++v) {
              R[v] = rnd128_next(v, vlen);
              r[v] = (R[v] & 0xFFU) << 5;
            }
          } else if (sr_idx == 1) {
            for (int v = 0; v < vlen; ++v) {
              r[v] = ((R[v] & 0xFF00U) >> 8) << 5;
            }
          } else if (sr_idx == 2) {
            for (int v = 0; v < vlen; ++v) {
              r[v] = ((R[v] & 0xFF0000U) >> 16) << 5;
            }
          } else { // 3
            for (int v = 0; v < vlen; ++v) {
              r[v] = ((R[v] & 0xFF000000U) >> 24) << 5;
            }
          }
        }

        for (int v = 0; v < cur_vlen; ++v) {
          int j = n * vlen + v;
          if (isFloat16w) {
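            // Type-pun via union so the random bits can be added directly to
            // the raw fp32 bit pattern, mirroring the JIT's integer add.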
            union {
              float w_f32;
              uint32_t w_i32;
            };
            w_f32 = cpu_half2float(w_[j]);
            w_f32 = std::fma(float_step, g_[j], w_f32);
            if (use_stochastic_rounding) {
              w_i32 += r[v];
            }
            // Use truncate rounding to 'counterwork' the random added part
            w_[j] = cpu_float2half_rz(w_f32);
          } else { // float
            w_[j] += g_[j] * float_step;
          }
        }
      }
    }
  }

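  // Success iff exactly index_size indices were consumed.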
  return current == index_size;
}
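
A minimal usage sketch (not part of the source): it drives the reference kernel with DataType = float, IndexType = int64_t, OffsetType = int64_t. The include path, the namespace, and the availability of this template instantiation are assumptions, not guarantees.

#include <cstdint>
#include <vector>
#include "./RefImplementations.h" // assumed include path inside fbgemm

int main() {
  constexpr int64_t block_size = 4;  // embedding dimension
  constexpr int64_t data_size = 10;  // number of embedding rows
  constexpr int64_t output_size = 2; // number of bags
  std::vector<int64_t> indices = {1, 3, 7}; // rows gathered across all bags
  std::vector<int64_t> offsets = {0, 2, 3}; // bag m covers [offsets[m], offsets[m+1])
  std::vector<float> w(data_size * block_size, 1.0f);   // weights (DataType = float)
  std::vector<float> h(data_size, 0.0f);                // per-row AdaGrad momentum
  std::vector<float> g(output_size * block_size, 0.1f); // one gradient row per bag

  int ok = fbgemm::rowwise_sparse_adagrad_fused_ref(
      block_size, output_size, /*index_size=*/3, data_size,
      w.data(), g.data(), h.data(), indices.data(), offsets.data(),
      /*epsilon=*/1e-5f, /*lr=*/0.01f, /*use_offsets=*/true,
      /*use_stochastic_rounding=*/false, /*emu_vector_size=*/8,
      /*grad_stride=*/-1);
  return ok ? 0 : 1;
}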
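
The stochastic-rounding bit trick described in the comments above can be demonstrated standalone. The sketch below substitutes std::mt19937 for rnd128_next and a plain mantissa truncation for cpu_float2half_rz, and it ignores fp16 exponent range and subnormals; it illustrates the technique, it is not the library's implementation.

#include <cstdint>
#include <cstring>
#include <random>

// Keep only the top 10 mantissa bits of an fp32 value (round-toward-zero),
// mimicking the mantissa loss of an fp32 -> fp16 conversion. Exponent range
// and subnormals are deliberately ignored in this sketch.
static float truncate_mantissa_to_half(uint32_t bits) {
  bits &= ~((1u << 13) - 1); // clear fp32 mantissa bits 0..12
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Add 8 random bits at mantissa bits 5..12, then truncate: the result rounds
// up with probability proportional to the fraction fp16 would discard.
float stochastic_round_mantissa(float x, std::mt19937& gen) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits += (gen() & 0xFFu) << 5;
  return truncate_mantissa_to_half(bits);
}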