in src/RefImplementations.cc [1092:1229]
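// Scalar reference implementation of SparseLengthsSum (SLS): for each of the
// `output_size` output rows, gather `len` input rows selected by `indices`,
// optionally scale each row by `weights`, accumulate in fp32, optionally
// normalize by the segment length, and write the pooled row to `out`.
// Returns false if any index or offset is out of range.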
bool EmbeddingSpMDM_ref(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const InType* input,
const IndexType* indices,
const OffsetType* offsets_or_lengths,
const float* weights, // optional, can be null for non-weighted sum
bool normalize_by_lengths,
OutType* out,
bool is_weight_positional,
bool use_offsets,
int64_t output_stride /*=-1*/,
int64_t input_stride /*=-1*/,
bool scale_bias_last) {
bool is8bit = is_same<InType, uint8_t>::value;
if (output_stride == -1) {
output_stride = block_size;
}
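  // Accumulate each pooled row in fp32, regardless of InType/OutType.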
vector<float> buf(block_size);
if (is8bit) {
    // block_size is the number of quantized elements per row; input_stride is
    // the size of an entire fused row in bytes, including the scale and bias.
if (input_stride == -1) {
      // scale_bias_last == false is used by table-batched embeddings, which
      // store the scale and bias in float16 at the start of each row.
const auto scale_bias_offset =
2 * (scale_bias_last ? sizeof(float) : sizeof(float16));
input_stride = block_size + scale_bias_offset;
}
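    // Example fused-row layouts for block_size = 64:
    //   scale_bias_last == true : [64 x uint8][fp32 scale][fp32 bias] -> input_stride = 72
    //   scale_bias_last == false: [fp16 scale][fp16 bias][64 x uint8] -> input_stride = 68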
int64_t current = 0;
for (int m = 0; m < output_size; ++m) {
memset(buf.data(), 0, sizeof(float) * block_size);
int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
: offsets_or_lengths[m];
if (current + len > index_size) {
return false;
}
for (int i = 0; i < len; ++i) {
int64_t idx = indices[current];
if (idx < 0 || idx >= data_size) {
return false;
}
const float* scale_bias = reinterpret_cast<const float*>(
input + input_stride * idx + (scale_bias_last ? block_size : 0));
float weight = 1.0f;
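        // is_weight_positional selects one weight per position within the
        // segment (shared across segments) instead of one weight per entry of
        // `indices`.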
if (weights) {
weight = weights[is_weight_positional ? i : current];
}
float scale, bias;
if (scale_bias_last) {
scale = weight * scale_bias[0];
bias = weight * scale_bias[1];
} else {
scale = weight *
cpu_half2float(reinterpret_cast<const float16*>(scale_bias)[0]);
bias = weight *
cpu_half2float(reinterpret_cast<const float16*>(scale_bias)[1]);
}
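        // Dequantize and accumulate:
        //   buf[j] += weight * (scale_raw * q + bias_raw)
        // with the weight already folded into `scale` and `bias` above.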
        for (int j = 0; j < block_size; ++j) {
          buf[j] = std::fma(
              scale,
              input[input_stride * idx + j +
                    (scale_bias_last ? 0 : 2 * sizeof(float16))],
              buf[j] + bias);
        }
++current;
}
if (normalize_by_lengths && len) {
float scale = 1.f / len;
for (int j = 0; j < block_size; ++j) {
buf[j] *= scale;
}
}
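      // Write the pooled row, converting to fp16 when OutType is float16.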
for (int j = 0; j < block_size; ++j) {
out[j] = is_same<OutType, float16>::value ? cpu_float2half_rn(buf[j])
: buf[j];
}
out += output_stride;
}
return current == index_size;
} else {
if (input_stride == -1) {
input_stride = block_size;
}
    // Reference implementation of the FP32/FP16 SLS path (InType is float or
    // float16).
int64_t current = 0;
for (int m = 0; m < output_size; ++m) {
memset(buf.data(), 0, sizeof(float) * block_size);
int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
: offsets_or_lengths[m];
if (current + len > index_size) {
return false;
}
for (int i = 0; i < len; ++i) {
int64_t idx = indices[current];
if (idx < 0 || idx >= data_size) {
return false;
}
float w = 1.f;
if (weights) {
w = weights[is_weight_positional ? i : current];
}
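        // Gather the row and accumulate, widening fp16 inputs to fp32.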
for (int j = 0; j < block_size; ++j) {
const InType* inptr = input + input_stride * idx + j;
buf[j] = std::fma(
w,
is_same<InType, float16>::value ? cpu_half2float(*inptr) : *inptr,
buf[j]);
}
++current;
}
if (normalize_by_lengths && len) {
float scale = 1.f / len;
for (int j = 0; j < block_size; ++j) {
buf[j] *= scale;
}
}
for (int j = 0; j < block_size; ++j) {
out[j] = is_same<OutType, float16>::value ? cpu_float2half_rn(buf[j])
: buf[j];
}
out += output_stride;
}
return current == index_size;
}
}
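A minimal usage sketch (not part of the source file), assuming the templated declaration of EmbeddingSpMDM_ref in RefImplementations.h and a <float, int64_t, int64_t, float> instantiation; the table contents, sizes, and the helper function name are illustrative only:

#include <cstdint>
#include <vector>

#include "./RefImplementations.h"

// Pools two segments of an fp32 table: rows {0, 3} and rows {7, 2}.
void embedding_spmdm_ref_usage_sketch() {
  constexpr int64_t block_size = 4;   // embedding dimension
  constexpr int64_t data_size = 10;   // rows in the embedding table
  constexpr int64_t output_size = 2;  // pooled output rows
  std::vector<float> table(data_size * block_size, 1.0f);
  std::vector<int64_t> indices = {0, 3, 7, 2};
  std::vector<int64_t> offsets = {0, 2, 4}; // output_size + 1 offsets
  std::vector<float> out(output_size * block_size);
  bool ok = fbgemm::EmbeddingSpMDM_ref<float, int64_t, int64_t, float>(
      block_size,
      output_size,
      static_cast<int64_t>(indices.size()),
      data_size,
      table.data(),
      indices.data(),
      offsets.data(),
      /*weights=*/nullptr,
      /*normalize_by_lengths=*/false,
      out.data(),
      /*is_weight_positional=*/false,
      /*use_offsets=*/true,
      /*output_stride=*/-1,
      /*input_stride=*/-1,
      /*scale_bias_last=*/true);
  (void)ok; // false would indicate an out-of-range index or offset
}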