in tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_fill_empty_rows_op.cu.cc [109:244]
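// GPU implementation of SparseFillEmptyRows for rank-2 sparse tensors: every
// dense row with no input values receives a single entry holding
// default_value. Outputs are the new indices and values plus the optional
// empty-row indicator and reverse index map.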
void SparseFillEmptyRowsGpuImpl(OpKernelContext* context,
const int64* input_indices,
const T* input_values, const int64 nnz,
const int64* input_shape,
const T* default_value) {
auto d = context->eigen_gpu_device();
auto OpStream = d.stream();
int64 dense_row_number;
// Copy the number of dense rows (dense_shape[0]) from device to host; the
// synchronize below makes it available for the host-side allocations.
// If the dense shape were already on the host, this copy could be skipped.
cudaMemcpyAsync(&dense_row_number, input_shape, sizeof(int64),
cudaMemcpyDeviceToHost, OpStream);
cudaStreamSynchronize(OpStream);
// Temporary buffers holding the start offset of each row in the input and
// output values.
Tensor input_row_offset;
Tensor output_row_offset;
// Temp buffer for the count kernel: the number of non-zero values in each
// row.
Tensor row_nnz_count;
// input_row_offset and output_row_offset have dense_row_number + 1 entries,
// because one extra slot is needed for the leading offset 0.
OP_REQUIRES_OK(context, context->allocate_temp(
DT_INT64, TensorShape({dense_row_number + 1}),
&input_row_offset));
OP_REQUIRES_OK(context, context->allocate_temp(
DT_INT64, TensorShape({dense_row_number + 1}),
&output_row_offset));
OP_REQUIRES_OK(
context, context->allocate_temp(
// Use DT_INT32 instead of DT_INT64: CUDA atomicAdd has no overload
// for signed 64-bit integers.
DT_INT32, TensorShape({dense_row_number}), &row_nnz_count));
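// Zero the per-row counters and the leading offset entry of each offset
// buffer; the remaining offsets are written by the prefix sums below.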
cudaMemset(row_nnz_count.flat<int>().data(), 0,
sizeof(int) * dense_row_number);
cudaMemset(input_row_offset.flat<int64>().data(), 0, sizeof(int64));
cudaMemset(output_row_offset.flat<int64>().data(), 0, sizeof(int64));
// Count the number of non-zero values in each row.
GpuLaunchConfig count_kernel_config = GetGpuLaunchConfig(nnz, d);
TF_CHECK_OK(GpuLaunchKernel(
SparseFillEmptyRowCountKernel, count_kernel_config.block_count,
count_kernel_config.thread_per_block, 0, d.stream(), input_indices, nnz,
input_shape, row_nnz_count.flat<int>().data(),
input_row_offset.flat<int64>().data(),
output_row_offset.flat<int64>().data()));
/* Calculate the offset of each input row.
 * Example: non-zero values per row:      [3, 4, 0, 0, 6]
 *          input row offsets after scan: [0, 3, 7, 7, 7, 13]
 */
// Determine temporary device storage requirements for the inclusive prefix
// sums.
size_t temp_storage_bytes = 0;
cub::DeviceScan::InclusiveSum(
nullptr, temp_storage_bytes, row_nnz_count.flat<int>().data(),
input_row_offset.flat<int64>().data() + 1, dense_row_number, OpStream);
// Allocate temporary storage for inclusive prefix sum
Tensor temp_storage;
OP_REQUIRES_OK(
context,
context->allocate_temp(
DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
&temp_storage));
void* d_temp_storage = temp_storage.flat<int8>().data();
// Run the inclusive prefix sum on the op's stream so it is ordered after the
// count kernel.
cub::DeviceScan::InclusiveSum(
d_temp_storage, temp_storage_bytes, row_nnz_count.flat<int>().data(),
input_row_offset.flat<int64>().data() + 1, dense_row_number, OpStream);
/* Set the count of every empty row to 1, since each empty row receives one
 * default value in the output.
 * Example: non-zero values per row (row_nnz_count): [3, 4, 0, 0, 6]
 *          row_nnz_count after the kernel:          [3, 4, 1, 1, 6]
 */
// The add-one kernel visits every dense row, so size the launch by the row
// count rather than by nnz.
GpuLaunchConfig add_kernel_config = GetGpuLaunchConfig(dense_row_number, d);
TF_CHECK_OK(GpuLaunchKernel(
SparseFillEmptyRowAddOneKernel, add_kernel_config.block_count,
add_kernel_config.thread_per_block, 0, d.stream(), input_shape,
row_nnz_count.flat<int>().data()));
// Calculate the offset of each output row.
cub::DeviceScan::InclusiveSum(
d_temp_storage, temp_storage_bytes, row_nnz_count.flat<int>().data(),
output_row_offset.flat<int64>().data() + 1, dense_row_number, OpStream);
// Read back the total number of output values: the last entry of
// output_row_offset equals nnz + number_of_empty_rows.
int64 output_nnz;
cudaMemcpyAsync(&output_nnz,
output_row_offset.flat<int64>().data() + dense_row_number,
sizeof(int64), cudaMemcpyDeviceToHost, OpStream);
cudaStreamSynchronize(OpStream);
// Allocate output tensors.
Tensor* output_indices;
Tensor* output_values;
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({output_nnz, 2}),
&output_indices));
OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({output_nnz}),
&output_values));
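// Optional output 2: empty_row_indicator[i] is true iff dense row i had no
// input values and was filled with default_value.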
bool* empty_row_indicator = nullptr;
if (context->output_required(2)) {
Tensor* empty_row_indicator_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(2, TensorShape({dense_row_number}),
&empty_row_indicator_t));
empty_row_indicator = empty_row_indicator_t->vec<bool>().data();
// Assume every row is non-empty to start; the fill kernel marks the rows
// that receive a default value.
cudaMemset(empty_row_indicator, 0, sizeof(bool) * dense_row_number);
}
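// Optional output 3: reverse_index_map[i] is the position of input value i
// in the output values tensor.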
int64* reverse_index_map = nullptr;
if (context->output_required(3)) {
Tensor* reverse_index_map_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(3, TensorShape({nnz}),
&reverse_index_map_t));
reverse_index_map = reverse_index_map_t->vec<int64>().data();
}
// Launch the fill kernel: copy existing values to their output positions and
// insert the default value into empty rows.
GpuLaunchConfig config = GetGpuLaunchConfig(dense_row_number, d);
TF_CHECK_OK(GpuLaunchKernel(
SparseFillEmptyRowFillKernel<T>, config.block_count,
config.thread_per_block, 0, d.stream(), input_indices, input_values,
input_shape, default_value, input_row_offset.flat<int64>().data(),
output_row_offset.flat<int64>().data(),
output_indices->flat<int64>().data(), output_values->flat<T>().data(),
empty_row_indicator, reverse_index_map));
}