bool EmbeddingSpMDMBlockSize1_()

in src/EmbeddingSpMDMAvx2.cc [18:128]


// Scalar EmbeddingSpMDM kernel for block size 1: each input row is a single
// element (only *(input + idx) is read per index). For every output bag m it
// accumulates the (optionally weighted) input elements selected by that bag's
// indices into out[m], optionally dividing the sum by the bag length.
//
// Parameters:
//   output_size          - number of bags; out[0 .. output_size) is written.
//   index_size           - total number of entries in `indices`.
//   data_size            - number of rows in `input`; every index must lie in
//                          [0, data_size).
//   input                - embedding table, one element per row. float16
//                          values are converted with cpu_half2float.
//   indices              - row indices, consumed sequentially bag by bag
//                          starting from indices[0].
//   offsets_or_lengths   - if use_offsets, output_size + 1 offsets and bag m
//                          has length offsets[m + 1] - offsets[m]; otherwise
//                          output_size per-bag lengths.
//   weights              - optional weights (nullptr = plain sum).
//   normalize_by_lengths - divide each non-empty bag's sum by its length.
//   out                  - output buffer of at least output_size floats.
//   is_weight_positional - index weights by position within the bag instead
//                          of by position in `indices`.
//   use_offsets          - selects how offsets_or_lengths is interpreted.
//
// Returns false (with `out` possibly partially written) if a bag length is
// negative, the bags overrun index_size, an index is out of range, or the
// bags do not consume exactly index_size indices; true otherwise.
bool EmbeddingSpMDMBlockSize1_(
    const std::int64_t output_size,
    const std::int64_t index_size,
    const std::int64_t data_size, // the number of rows in input
    const InType* input,
    const IndexType* indices,
    const OffsetType* offsets_or_lengths,
    const float* weights, // optional, can be null for non-weighted sum
    bool normalize_by_lengths,
    float* out,
    bool is_weight_positional,
    bool use_offsets) {
  // Position of the next unconsumed entry in `indices`.
  std::int64_t current = 0;
  // 64-bit counter: output_size is int64_t, and an int counter could
  // overflow (undefined behavior) for very large outputs.
  for (std::int64_t m = 0; m < output_size; ++m) {
    out[m] = 0;
    // Keep the bag length in 64 bits; OffsetType may be 64-bit, and
    // truncating to int would silently corrupt the length.
    const std::int64_t len = use_offsets
        ? static_cast<std::int64_t>(offsets_or_lengths[m + 1]) -
            static_cast<std::int64_t>(offsets_or_lengths[m])
        : static_cast<std::int64_t>(offsets_or_lengths[m]);
    // A negative length (e.g. non-monotonic offsets) would slip past the
    // overrun check and silently produce an empty bag, so reject it too.
    if (len < 0 || current + len > index_size) {
      return false;
    }

    // A gather-based AVX2 version of this loop was tried and did not speed
    // it up, so the scalar loop is kept.
    for (std::int64_t i = 0; i < len; ++i) {
      const std::int64_t idx = indices[current];
      if (idx < 0 || idx >= data_size) {
        return false;
      }

      float w = 1.f;
      if (weights) {
        w = weights[is_weight_positional ? i : current];
      }

      // Use the index validated above rather than re-reading indices[current].
      const InType* inptr = input + idx;
      out[m] = std::fma(
          w,
          std::is_same<InType, float16>::value ? cpu_half2float(*inptr)
                                               : *inptr,
          out[m]);

      ++current;
    }
    if (normalize_by_lengths && len) {
      const float scale = 1.f / len;
      out[m] *= scale;
    }
  }
  // Every index must have been consumed by exactly one bag.
  return current == index_size;
}