in fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp [25:164]
void split_embedding_forward_cpu_kernel(
    Tensor weights,
    Tensor weights_offsets,
    Tensor D_offsets,
    int64_t total_D,
    Tensor hash_size_cumsum,
    Tensor indices,
    Tensor offsets,
    int64_t pooling_mode,
    Tensor indice_weights,
    Tensor output) {
  // Pooled (bag) embedding forward on CPU for T tables stored back-to-back in
  // one flat `weights` buffer.
  //
  //   weights          flat weight storage for all T tables (element type
  //                    weights_t); must be contiguous
  //   weights_offsets  [T]   start offset of each table inside `weights`
  //   D_offsets        [T+1] prefix sum of per-table embedding dims
  //   total_D          sum of all embedding dims (unused in this kernel)
  //   hash_size_cumsum [T+1] prefix sum of per-table row counts; used to
  //                    bound-check indices and to skip zero-row tables
  //   indices/offsets  CSR-style bag description; offsets has T*B+1 entries
  //   pooling_mode     PoolingMode (SUM or MEAN behavior handled below)
  //   indice_weights   optional per-index scale factors (may be undefined)
  //   output           [B][output_stride] destination; table t writes its
  //                    D-wide slice at column D_begin
  int64_t T = D_offsets.numel() - 1;
  TORCH_CHECK(T > 0);
  // offsets = [T x B + 1]
  int64_t B = (offsets.size(0) - 1) / T;
  TORCH_CHECK(B >= 0);
  TORCH_CHECK(weights.is_contiguous());
  indices = indices.contiguous();
  offsets = offsets.contiguous();
  if (indice_weights.defined()) {
    indice_weights = indice_weights.contiguous();
  }
  const auto D_offsets_data = D_offsets.accessor<int, 1>();
  const auto weights_offsets_data = weights_offsets.accessor<int64_t, 1>();
  const auto indices_data = indices.data_ptr<int64_t>();
  const auto offsets_data = offsets.data_ptr<int64_t>();
  const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();
  const auto weights_data = weights.data_ptr<weights_t>();
  // If indice_weights not defined, then this accessor won't be used.
  // The else condition is just to make compiler happy
  const auto indice_weights_data = indice_weights.defined()
      ? indice_weights.data_ptr<ind_weights_t>()
      : nullptr;
  auto output_data = output.data_ptr<output_t>();
  auto output_stride = output.size(1);

  // The JIT'ed fbgemm kernel only supports this subset of type combinations;
  // every other instantiation takes the scalar reference loop below.
  constexpr bool use_fbgemm = (std::is_same<weights_t, float>::value ||
                               std::is_same<weights_t, at::Half>::value ||
                               std::is_same<weights_t, uint8_t>::value) &&
      std::is_same<output_t, float>::value &&
      std::is_same<ind_weights_t, float>::value;

  // Parallelize over the batch dimension; each task processes rows
  // [b_begin, b_end) of every table.
  at::parallel_for(0, B, 0, [&](int64_t b_begin, int64_t b_end) {
    for (int t = 0; t < T; ++t) {
      const auto D_begin = D_offsets_data[t];
      const auto D = D_offsets_data[t + 1] - D_offsets_data[t];
      const auto table_begin = weights_offsets_data[t];
      // Row count of table t. Tables with zero rows are skipped by scanning
      // ahead in the cumsum until a non-zero span appears. NOTE(review): this
      // assumes the final cumsum entry is a positive grand total so the scan
      // terminates in bounds — confirm against how the caller builds
      // hash_size_cumsum.
      int64_t hash_size;
      int t_temp = t + 1;
      do {
        hash_size = hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[t];
        ++t_temp;
      } while (hash_size == 0);
      bool success = true;
      // `if constexpr`: the condition is compile-time constant, so only the
      // branch that can actually run is instantiated. With a plain `if`, the
      // fbgemm path (and e.g. GenerateEmbeddingSpMDMWithStrides<double>) would
      // be instantiated even for type combos that can never take it.
      if constexpr (use_fbgemm) {
        using fbgemm_weight_t = typename std::conditional<
            std::is_same<weights_t, at::Half>::value,
            fbgemm::float16,
            weights_t>::type;
        auto kernel = fbgemm::GenerateEmbeddingSpMDMWithStrides<
            fbgemm_weight_t,
            /*IndexType=*/int64_t,
            /*OffsetType=*/int64_t>(
            D,
            indice_weights.defined(),
            static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN,
            /*prefetch=*/16,
            /*is_weight_positional=*/false,
            /*use_offsets=*/true,
            output_stride);
        auto offsets_begin_ptr = offsets_data + t * B + b_begin;
        auto indices_size = offsets_data[t * B + b_end] - *offsets_begin_ptr;
        // The generated kernel returns false on failure (e.g. an out-of-range
        // index), which triggers the error report below.
        success = kernel(
            b_end - b_begin,
            indices_size,
            hash_size,
            reinterpret_cast<const fbgemm_weight_t*>(
                weights_data + table_begin),
            indices_data + *offsets_begin_ptr,
            offsets_begin_ptr,
            indice_weights.defined()
                ? reinterpret_cast<const float*>(
                      indice_weights_data + *offsets_begin_ptr)
                : nullptr,
            reinterpret_cast<float*>(
                output_data + b_begin * output_stride + D_begin));
      } else {
        // Scalar reference path. NOTE(review): a variable-length array is a
        // GCC/Clang extension, not standard C++; kept as-is to avoid a
        // per-task heap allocation.
        at::acc_type<output_t, true> output_buf[D];
        for (int b = b_begin; b < b_end; ++b) {
          const auto pool_begin = offsets_data[t * B + b];
          const auto pool_end = offsets_data[t * B + b + 1];
          const auto L = pool_end - pool_begin;
          // Accumulate each bag in a higher-precision type, write back once.
          memset(output_buf, 0, D * sizeof(at::acc_type<output_t, true>));
          for (auto p = pool_begin; p < pool_end; ++p) {
            int64_t idx = indices_data[p];
            if (idx < 0 || idx >= hash_size) {
              success = false;
              break;
            }
            const int64_t embedding_begin = table_begin + idx * D;
            for (int64_t d = 0; d < D; ++d) {
              output_buf[d] +=
                  (indice_weights.defined()
                       ? static_cast<at::acc_type<output_t, true>>(
                             weights_data[embedding_begin + d]) *
                           static_cast<at::acc_type<output_t, true>>(
                               indice_weights_data[p])
                       : static_cast<at::acc_type<output_t, true>>(
                             weights_data[embedding_begin + d]));
            }
          }
          const double scale_factor =
              // NOTE: MEAN pooling will not work with indice_weights!
              (static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN &&
               !indice_weights.defined() && L > 0)
              ? 1.0 / L
              : 1.0;
          for (int d = 0; d < D; ++d) {
            output_data[b * output_stride + D_begin + d] =
                scale_factor * output_buf[d];
          }
          if (!success) {
            // The partially accumulated bag was still written above; the bad
            // index is reported once we leave the b loop.
            break;
          }
        } // for each b
      } // !use_fbgemm
      if (!success) {
        fbgemm_gpu::report_embedding_error(
            t, B, b_begin, b_end, offsets_data, indices_data, hash_size);
      } // !success
    } // for each t
  }); // parallel for
}