in src/RefImplementations.cc [1232:1316]
bool EmbeddingSpMDMNBit_ref(
int bit_rate,
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const uint8_t* input,
const IndexType* indices,
const OffsetType* offsets_or_lengths,
const float* weights, // optional, can be null for non-weighted sum
bool normalize_by_lengths,
OutType* out,
bool is_weight_positional,
bool use_offsets,
int64_t output_stride,
int64_t input_stride,
bool scale_bias_last) {
assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4");
int num_elem_per_byte = 8 / bit_rate;
if (output_stride == -1) {
output_stride = block_size;
}
// block_size is the number of elements and fused_block_size is the size of
// an entire row, including scale and bias.
const auto scale_bias_offset = 2 * sizeof(float16);
if (input_stride == -1) {
input_stride = (block_size + num_elem_per_byte - 1) / num_elem_per_byte +
scale_bias_offset;
}
int64_t current = 0;
vector<float> buf(block_size);
for (int m = 0; m < output_size; ++m) {
memset(buf.data(), 0, sizeof(float) * block_size);
int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
: offsets_or_lengths[m];
if (current + len > index_size) {
return false;
}
for (int i = 0; i < len; ++i) {
int64_t idx = indices[current];
if (idx < 0 || idx >= data_size) {
return false;
}
const float16* scale_bias = reinterpret_cast<const float16*>(
input + input_stride * idx +
(scale_bias_last
? (block_size + num_elem_per_byte - 1) / num_elem_per_byte
: 0));
float weight = 1.0f;
if (weights) {
weight = weights[is_weight_positional ? i : current];
}
const float scale = weight * cpu_half2float(scale_bias[0]);
const float bias = weight * cpu_half2float(scale_bias[1]);
for (int j = 0; j < block_size; ++j) {
uint8_t quantized = input
[input_stride * idx + j / num_elem_per_byte +
(scale_bias_last ? 0 : scale_bias_offset)];
quantized >>= (j % num_elem_per_byte) * bit_rate;
quantized &= (1 << bit_rate) - 1;
buf[j] = std::fma(scale, quantized, buf[j] + bias);
}
++current;
}
if (normalize_by_lengths && len) {
float scale = 1.f / len;
for (int j = 0; j < block_size; ++j) {
buf[j] *= scale;
}
}
for (int j = 0; j < block_size; ++j) {
out[j] = std::is_same<OutType, float16>::value ? cpu_float2half_rn(buf[j])
: buf[j];
}
out += output_stride;
}
return current == index_size;
}