bool EmbeddingSpMDM_ref()

in src/RefImplementations.cc [1092:1229]
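
Reference (scalar) implementation of EmbeddingSpMDM, the sparse-lengths-sum primitive: for each of output_size output rows it gathers the rows of the embedding table selected by indices, optionally scales each by a per-index weight, accumulates in fp32, optionally normalizes by the segment length, and writes the result as fp32 or fp16. There are two paths depending on InType: a fused 8-bit path, where each row carries its own quantization scale and bias, and a plain fp32/fp16 path. The function returns false if an index is out of bounds or the offsets/lengths do not cover exactly index_size indices.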


bool EmbeddingSpMDM_ref(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const InType* input,
    const IndexType* indices,
    const OffsetType* offsets_or_lengths,
    const float* weights, // optional, can be null for non-weighted sum
    bool normalize_by_lengths,
    OutType* out,
    bool is_weight_positional,
    bool use_offsets,
    int64_t output_stride /*=-1*/,
    int64_t input_stride /*=-1*/,
    bool scale_bias_last) {
  bool is8bit = is_same<InType, uint8_t>::value;
  if (output_stride == -1) {
    output_stride = block_size;
  }

  vector<float> buf(block_size);
  if (is8bit) {
    // block_size is the number of elements in a row and input_stride is the
    // size of an entire fused row, including scale and bias.
    if (input_stride == -1) {
      // scale_bias_last == false is for table batched embedding that stores
      // scale and bias in float16
      const auto scale_bias_offset =
          2 * (scale_bias_last ? sizeof(float) : sizeof(float16));
      input_stride = block_size + scale_bias_offset;
    }
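    // A fused 8-bit row is therefore laid out as
    //   [block_size x uint8 | float scale | float bias]    (scale_bias_last)
    // or
    //   [float16 scale | float16 bias | block_size x uint8] (otherwise).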
    int64_t current = 0;
    for (int m = 0; m < output_size; ++m) {
      memset(buf.data(), 0, sizeof(float) * block_size);
      // Segment length for output row m: adjacent difference of offsets in
      // offsets mode (CSR-style), or the length directly in lengths mode.
      int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
                            : offsets_or_lengths[m];
      if (current + len > index_size) {
        return false;
      }
      for (int i = 0; i < len; ++i) {
        int64_t idx = indices[current];
        if (idx < 0 || idx >= data_size) {
          return false;
        }

        // Point at this row's scale/bias pair (see the layout note above).
        const float* scale_bias = reinterpret_cast<const float*>(
            input + input_stride * idx + (scale_bias_last ? block_size : 0));

        float weight = 1.0f;
        if (weights) {
          weight = weights[is_weight_positional ? i : current];
        }
        // Fold the optional per-index weight into scale and bias so the
        // element loop below needs only a single fma per element.
        float scale, bias;
        if (scale_bias_last) {
          scale = weight * scale_bias[0];
          bias = weight * scale_bias[1];
        } else {
          scale = weight *
              cpu_half2float(reinterpret_cast<const float16*>(scale_bias)[0]);
          bias = weight *
              cpu_half2float(reinterpret_cast<const float16*>(scale_bias)[1]);
        }

        // buf[j] += weight * (scale * q + bias); when the scale/bias pair
        // leads the row, the uint8 data starts 2 * sizeof(float16) bytes in.
        for (int j = 0; j < block_size; ++j) {
          buf[j] = std::fma(
              scale,
              input
                  [input_stride * idx + j +
                   (scale_bias_last ? 0 : 2 * sizeof(float16))],
              buf[j] + bias);
        }

        ++current;
      }
      if (normalize_by_lengths && len) {
        // Mean pooling: divide the accumulated sum by the segment length.
        float scale = 1.f / len;
        for (int j = 0; j < block_size; ++j) {
          buf[j] *= scale;
        }
      }
      // Write back the accumulated row, rounding to half precision when
      // OutType is float16.
      for (int j = 0; j < block_size; ++j) {
        out[j] = is_same<OutType, float16>::value ? cpu_float2half_rn(buf[j])
                                                  : buf[j];
      }
      out += output_stride;
    }
    // Succeed only if the offsets/lengths consumed exactly index_size indices.
    return current == index_size;
  } else {
    if (input_stride == -1) {
      input_stride = block_size;
    }

    // Reference implementation of FP32/FP16 SLS (SparseLengthsSum).
    int64_t current = 0;
    for (int m = 0; m < output_size; ++m) {
      memset(buf.data(), 0, sizeof(float) * block_size);
      int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
                            : offsets_or_lengths[m];
      if (current + len > index_size) {
        return false;
      }
      for (int i = 0; i < len; ++i) {
        int64_t idx = indices[current];
        if (idx < 0 || idx >= data_size) {
          return false;
        }

        float w = 1.f;
        if (weights) {
          w = weights[is_weight_positional ? i : current];
        }

        for (int j = 0; j < block_size; ++j) {
          const InType* inptr = input + input_stride * idx + j;
          // Convert float16 storage to fp32 on the fly; accumulate with fma.
          buf[j] = std::fma(
              w,
              is_same<InType, float16>::value ? cpu_half2float(*inptr) : *inptr,
              buf[j]);
        }

        ++current;
      }
      if (normalize_by_lengths && len) {
        float scale = 1.f / len;
        for (int j = 0; j < block_size; ++j) {
          buf[j] *= scale;
        }
      }
      for (int j = 0; j < block_size; ++j) {
        out[j] = is_same<OutType, float16>::value ? cpu_float2half_rn(buf[j])
                                                  : buf[j];
      }
      out += output_stride;
    }
    return current == index_size;
  }
}
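
A minimal calling sketch for the FP32 path, not taken from the source file: it assumes the surrounding template declaration template <typename InType, typename IndexType, typename OffsetType, typename OutType>, the fbgemm namespace, and a declaration in src/RefImplementations.h; data values are illustrative. Output row 0 sums table rows 1 and 3; output row 1 copies table row 2.

#include <cstdint>
#include <cstdio>
#include <vector>

#include "./RefImplementations.h" // assumed header declaring fbgemm::EmbeddingSpMDM_ref

int main() {
  const int64_t block_size = 2, output_size = 2, index_size = 3, data_size = 4;
  // Four embedding rows of two floats each.
  std::vector<float> table = {0.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
  std::vector<int64_t> indices = {1, 3, 2};
  std::vector<int32_t> offsets = {0, 2, 3}; // CSR-style, use_offsets == true
  std::vector<float> out(output_size * block_size, 0.f);

  bool ok = fbgemm::EmbeddingSpMDM_ref<float, int64_t, int32_t, float>(
      block_size, output_size, index_size, data_size,
      table.data(), indices.data(), offsets.data(),
      /*weights=*/nullptr, /*normalize_by_lengths=*/false, out.data(),
      /*is_weight_positional=*/false, /*use_offsets=*/true,
      /*output_stride=*/-1, /*input_stride=*/-1, /*scale_bias_last=*/true);

  // Expected: ok == true, out == {6, 8, 3, 4}.
  std::printf("ok=%d out=[%g %g %g %g]\n",
              static_cast<int>(ok), out[0], out[1], out[2], out[3]);
  return ok ? 0 : 1;
}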