in tensorflow_addons/custom_ops/layers/cc/kernels/embedding_bag_backward_kernels.cu.cc [150:233]
void operator()(const GPUDevice &d,
typename TTypes<Tindices, 2>::ConstTensor indices,
typename TTypes<T, 2>::ConstTensor params,
typename TTypes<T, 2>::ConstTensor weights,
typename TTypes<T, 2>::ConstTensor grads,
typename TTypes<T, 2>::Tensor params_grads,
typename TTypes<T, 2>::Tensor weights_grads,
Combiner combiner, OpKernelContext *context) {
// Allocate GPU-resident scratch tensors: one holds a sortable copy of the
// indices, the other records each index's original flat position (same
// temp-allocation pattern as histogram_op_gpu.cu.cc).
tensorflow::AllocatorAttributes gpu_allocator;
gpu_allocator.set_on_host(false);
gpu_allocator.set_gpu_compatible(true);
Tensor sortedIndicesTensor;
Tensor sortedIndicesCounterTensor;
OP_REQUIRES_OK(context,
context->allocate_temp(DataTypeToEnum<Tindices>::value,
TensorShape({indices.size()}),
&sortedIndicesTensor, gpu_allocator));
OP_REQUIRES_OK(context, context->allocate_temp(
DataTypeToEnum<Tindices>::value,
TensorShape({indices.size()}),
&sortedIndicesCounterTensor, gpu_allocator));
auto sortedIndices = sortedIndicesTensor.flat<Tindices>();
auto sortedIndicesCounter = sortedIndicesCounterTensor.flat<Tindices>();
// Note: launching the two kernels below on separate streams was tried, but
// it barely affected performance.
const Eigen::Index batch_dim = indices.dimension(0);
const Eigen::Index bag_dim = indices.dimension(1);
const Eigen::Index output_dim = params.dimension(1);
const auto params_size = params.size();
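// Weights gradient: one thread block per (sample, bag position) pair, with
// kThreadsPerBlock threads cooperating on the reduction over output_dim.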
const int kThreadsPerBlock = 32;
dim3 gridShape = dim3(batch_dim, bag_dim, 1);
TF_CHECK_OK(GpuLaunchKernel(
EmbeddingBagWeightsGradKernel<T, Tindices, kThreadsPerBlock>, gridShape,
kThreadsPerBlock, 0, d.stream(), output_dim, indices.data(),
params.data(), grads.data(), weights_grads.data(), combiner));
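// Stage the scratch arrays for the params gradient: copy the indices and
// record each index's original flat position so the matching grad rows can
// be located after sorting.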
const int indices_size = indices.size();
const int total_blocks = Eigen::divup(indices_size, kThreadsPerBlock);
gridShape = dim3(total_blocks, 1, 1);
TF_CHECK_OK(GpuLaunchKernel(
PrepTempArraysKernel<Tindices, kThreadsPerBlock>, gridShape,
kThreadsPerBlock, 0, d.stream(), indices.data(), sortedIndices.data(),
sortedIndicesCounter.data(), indices_size));
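// Zero the params gradient, then sort the copied indices (carrying their
// original positions as values) so that duplicate indices form contiguous
// runs.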
thrust::device_ptr<Tindices> sortedIndicesCounterDevicePtr(
sortedIndicesCounter.data());
thrust::device_ptr<Tindices> sortedIndicesDevicePtr(sortedIndices.data());
thrust::device_ptr<T> paramsGradDevicePtr(params_grads.data());
thrust::fill(paramsGradDevicePtr,
paramsGradDevicePtr + static_cast<int>(params_size),
static_cast<T>(0.0f));
thrust::sort_by_key(sortedIndicesDevicePtr,
sortedIndicesDevicePtr + indices_size,
sortedIndicesCounterDevicePtr);
// Handle each row with as few thread blocks as possible
int threadsPerBlock;
int blocksPerRow;
if (output_dim <= MAX_THREADS_PER_BLOCK) {
blocksPerRow = 1;
threadsPerBlock = output_dim;
} else {
blocksPerRow =
Eigen::divup(static_cast<int>(output_dim), MAX_THREADS_PER_BLOCK);
threadsPerBlock =
Eigen::divup(static_cast<int>(output_dim), blocksPerRow);
}
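// threadsPerBlock * blocksPerRow now covers output_dim, and threadsPerBlock
// never exceeds MAX_THREADS_PER_BLOCK.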
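// Values gradient: a row of blocksPerRow blocks (threadsPerBlock threads
// each) per index entry; with duplicates contiguous, all grad rows for a
// given embedding row can be accumulated into its params_grads row.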
gridShape = dim3(indices_size, blocksPerRow, 1);
TF_CHECK_OK(GpuLaunchKernel(
EmbeddingBagValuesGradKernel<T, Tindices>, gridShape, threadsPerBlock,
0, d.stream(), output_dim, bag_dim, sortedIndices.data(),
sortedIndicesCounter.data(), params.data(), weights.data(),
grads.data(), params_grads.data(), combiner));
}