in tensorflow_addons/custom_ops/layers/cc/kernels/embedding_bag_ops.cc [86:156]
// Computes the gradients of the embedding-bag op on the CPU.
//
// As implied by the gradients computed below, the forward pass combines, for
// each bag `b`, the rows `weights(b, s) * params(indices(b, s), :)` over
// `s in [0, sequence_length)` with a sum (or, for Combiner::kMean, a sum
// divided by `sequence_length`).
//
// Args:
//   device:        Eigen CPU device used for the parallel work.
//   indices:       [num_bags, sequence_length]; rows of `params` to gather.
//   params:        embedding table, indexed as params(row, 0..output_dim).
//   weights:       per-index weights, indexed as weights(bag, seq).
//   grads:         upstream gradient, indexed as grads(bag, 0..output_dim).
//   params_grads:  output; overwritten with the gradient w.r.t. `params`
//                  (presumably the same shape as `params` -- TODO confirm
//                  against the calling kernel).
//   weights_grads: output; overwritten with the gradient w.r.t. `weights`.
//   combiner:      reduction used by the forward pass.
//   context:       used to report invalid `indices`; assumed non-null, as
//                  supplied by the calling kernel.
//
// If any entry of `indices` is outside [0, params.dimension(0)), an
// InvalidArgument error is set on `context` and neither output is written.
void operator()(const CPUDevice &device,
                typename TTypes<Tindices, 2>::ConstTensor indices,
                typename TTypes<T, 2>::ConstTensor params,
                typename TTypes<T, 2>::ConstTensor weights,
                typename TTypes<T, 2>::ConstTensor grads,
                typename TTypes<T, 2>::Tensor params_grads,
                typename TTypes<T, 2>::Tensor weights_grads,
                Combiner combiner, OpKernelContext *context) {
  const Eigen::Index sequence_length = indices.dimension(1);
  const Eigen::Index output_dim = params.dimension(1);
  const Eigen::Index num_rows = params.dimension(0);

  // Group the flat positions of `indices` by the embedding row they gather.
  // An entry (row, {p_i}) in `index_vec` means flat position p_i of `indices`
  // (bag = p_i / sequence_length, seq = p_i % sequence_length, since TF
  // tensor maps are row-major) contributes to the gradient of `params` row
  // `row`. `index_map` maps a row to its slot in `index_vec`.
  std::unordered_map<Tindices, Eigen::Index> index_map;
  std::vector<std::pair<Eigen::Index, std::vector<Eigen::Index>>> index_vec;
  for (Eigen::Index i = 0; i < indices.size(); ++i) {
    const Tindices index = indices.data()[i];
    // Validate before any row is dereferenced: an out-of-range index would
    // otherwise read `params` and write `params_grads` out of bounds.
    OP_REQUIRES(
        context,
        index >= 0 && static_cast<Eigen::Index>(index) < num_rows &&
            static_cast<Eigen::Index>(index) < params_grads.dimension(0),
        errors::InvalidArgument("indices[", i / sequence_length, ",",
                                i % sequence_length, "] = ", index,
                                " is not in [0, ", num_rows, ")"));
    // Single hash lookup: a new group is created only on first sighting.
    const auto inserted = index_map.emplace(
        index, static_cast<Eigen::Index>(index_vec.size()));
    if (inserted.second) {
      index_vec.emplace_back(static_cast<Eigen::Index>(index),
                             std::vector<Eigen::Index>());
    }
    index_vec[inserted.first->second].second.push_back(i);
  }

  // Accumulates the gradient of each unique row in [start, end) of
  // `index_vec`. Every row appears in exactly one group, so concurrent tasks
  // write disjoint rows of `params_grads` and never race.
  const auto compute_params_grads = [&](Eigen::Index start,
                                        Eigen::Index end) {
    for (Eigen::Index i = start; i < end; ++i) {
      VectorMap params_grads_slice(&params_grads(index_vec[i].first, 0),
                                   output_dim);
      for (const Eigen::Index index : index_vec[i].second) {
        const Eigen::Index bag = index / sequence_length;
        const Eigen::Index seq = index % sequence_length;
        const ConstVectorMap grads_slice(&grads(bag, 0), output_dim);
        params_grads_slice += grads_slice * weights(bag, seq);
      }
      if (combiner == Combiner::kMean) {
        params_grads_slice /= static_cast<T>(sequence_length);
      }
    }
  };

  const Eigen::Index num_unique_params = index_vec.size();
  // Per-row cost hint for parallelFor. This is a fixed heuristic that appears
  // to assume ~100 accumulated occurrences per unique row; it affects only
  // work partitioning, not the computed values.
  const double bytes_loaded = 100 * output_dim * sizeof(T);
  const double bytes_stored = output_dim * sizeof(T);
  const double compute_cycles =
      100 * output_dim *
      (Eigen::TensorOpCost::AddCost<T>() + Eigen::TensorOpCost::MulCost<T>());
  const Eigen::TensorOpCost cost(bytes_loaded, bytes_stored, compute_cycles,
                                 /*vectorized=*/true,
                                 /*packet_size=*/kPacketSize);
  // Rows never referenced by `indices` keep a zero gradient.
  params_grads.setZero();
  device.parallelFor(num_unique_params, cost, compute_params_grads);

  // d(loss)/d(weights(bag, seq)) is the dot product of the gathered
  // embedding row with the bag's upstream gradient (scaled for kMean).
  const auto compute_weights_grads =
      [&](const Eigen::array<Eigen::Index, 2> &coords) -> T {
    const Eigen::Index bag = coords[0];
    const Eigen::Index seq = coords[1];
    const ConstVectorMap grads_slice(&grads(bag, 0), output_dim);
    const ConstVectorMap params_slice(&params(indices(bag, seq), 0),
                                      output_dim);
    T output = params_slice.dot(grads_slice);
    if (combiner == Combiner::kMean) {
      output /= static_cast<T>(sequence_length);
    }
    return output;
  };
  weights_grads.device(device) = weights_grads.generate(compute_weights_grads);
}