bool EmbeddingSpMDMNBit_ref()

in src/RefImplementations.cc [1232:1316]


bool EmbeddingSpMDMNBit_ref(
    int bit_rate,
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const IndexType* indices,
    const OffsetType* offsets_or_lengths,
    const float* weights, // optional, can be null for non-weighted sum
    bool normalize_by_lengths,
    OutType* out,
    bool is_weight_positional,
    bool use_offsets,
    int64_t output_stride,
    int64_t input_stride,
    bool scale_bias_last) {
  assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4");
  int num_elem_per_byte = 8 / bit_rate;

  if (output_stride == -1) {
    output_stride = block_size;
  }

  // block_size is the number of elements and fused_block_size is the size of
  // an entire row, including scale and bias.
  const auto scale_bias_offset = 2 * sizeof(float16);
  if (input_stride == -1) {
    input_stride = (block_size + num_elem_per_byte - 1) / num_elem_per_byte +
        scale_bias_offset;
  }
  int64_t current = 0;
  vector<float> buf(block_size);
  for (int m = 0; m < output_size; ++m) {
    memset(buf.data(), 0, sizeof(float) * block_size);
    int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
                          : offsets_or_lengths[m];
    if (current + len > index_size) {
      return false;
    }
    for (int i = 0; i < len; ++i) {
      int64_t idx = indices[current];
      if (idx < 0 || idx >= data_size) {
        return false;
      }

      const float16* scale_bias = reinterpret_cast<const float16*>(
          input + input_stride * idx +
          (scale_bias_last
               ? (block_size + num_elem_per_byte - 1) / num_elem_per_byte
               : 0));

      float weight = 1.0f;
      if (weights) {
        weight = weights[is_weight_positional ? i : current];
      }
      const float scale = weight * cpu_half2float(scale_bias[0]);
      const float bias = weight * cpu_half2float(scale_bias[1]);

      for (int j = 0; j < block_size; ++j) {
        uint8_t quantized = input
            [input_stride * idx + j / num_elem_per_byte +
             (scale_bias_last ? 0 : scale_bias_offset)];
        quantized >>= (j % num_elem_per_byte) * bit_rate;
        quantized &= (1 << bit_rate) - 1;

        buf[j] = std::fma(scale, quantized, buf[j] + bias);
      }

      ++current;
    }
    if (normalize_by_lengths && len) {
      float scale = 1.f / len;
      for (int j = 0; j < block_size; ++j) {
        buf[j] *= scale;
      }
    }
    for (int j = 0; j < block_size; ++j) {
      out[j] = std::is_same<OutType, float16>::value ? cpu_float2half_rn(buf[j])
                                                     : buf[j];
    }
    out += output_stride;
  }
  return current == index_size;
}