in src/RefImplementations.cc [1319:1440]
bool EmbeddingSpMDMRowWiseSparse_ref(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t uncompressed_data_size,
// const int64_t compressed_data_size,
const InType* input,
const IndexType* indices,
const int32_t* compressed_indices_table,
const OffsetType* offsets_or_lengths,
const float* weights, // optional, can be null for non-weighted sum
bool normalize_by_lengths,
float* out,
bool is_weight_positional,
bool use_offsets) {
bool is8bit = is_same<InType, uint8_t>::value;
if (is8bit) {
// block_size is the number of elements and fused_block_size is the size of
// an entire row, including scale and bias.
const auto scale_bias_offset = 2 * sizeof(float);
const int64_t fused_block_size = block_size + scale_bias_offset;
int64_t current = 0;
for (int m = 0; m < output_size; ++m) {
memset(out, 0, sizeof(float) * block_size);
int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
: offsets_or_lengths[m];
if (current + len > index_size) {
return false;
}
for (int i = 0; i < len; ++i) {
IndexType uncompressed_idx = indices[current];
if (uncompressed_idx < 0 ||
uncompressed_idx >= uncompressed_data_size) {
return false;
}
IndexType idx = compressed_indices_table[uncompressed_idx];
if (idx == -1) {
++current;
continue;
}
// if (idx < 0 || idx >= compressed_data_size) {
// return false;
// }
const float* scale_bias = reinterpret_cast<const float*>(
input + fused_block_size * idx + block_size);
float weight = 1.0f;
if (weights) {
weight = weights[is_weight_positional ? i : current];
}
const float scale = weight * scale_bias[0];
const float bias = weight * scale_bias[1];
for (int j = 0; j < block_size; ++j) {
out[j] =
std::fma(scale, input[fused_block_size * idx + j], out[j] + bias);
}
++current;
}
if (normalize_by_lengths && len) {
float scale = 1.f / len;
for (int j = 0; j < block_size; ++j) {
out[j] *= scale;
}
}
out += block_size;
}
return current == index_size;
} else {
// Reference implementation of FP32 SLS
int64_t current = 0;
for (int m = 0; m < output_size; ++m) {
memset(out, 0, sizeof(float) * block_size);
int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
: offsets_or_lengths[m];
if (current + len > index_size) {
return false;
}
for (int i = 0; i < len; ++i) {
IndexType uncompressed_idx = indices[current];
if (uncompressed_idx < 0 ||
uncompressed_idx >= uncompressed_data_size) {
return false;
}
IndexType idx = compressed_indices_table[uncompressed_idx];
if (idx == -1) {
++current;
continue;
}
// if (idx < 0 || idx >= compressed_data_size) {
// return false;
// }
float w = 1.f;
if (weights) {
w = weights[is_weight_positional ? i : current];
}
for (int j = 0; j < block_size; ++j) {
const InType* inptr = input + block_size * idx + j;
out[j] = std::fma(
w,
is_same<InType, float16>::value ? cpu_half2float(*inptr) : *inptr,
out[j]);
}
++current;
}
if (normalize_by_lengths && len) {
float scale = 1.f / len;
for (int j = 0; j < block_size; ++j) {
out[j] *= scale;
}
}
out += block_size;
}
return current == index_size;
}
}