void SparseFillEmptyRowsGpuImpl()

in tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_fill_empty_rows_op.cu.cc [109:244]


// Host-side driver for the GPU SparseFillEmptyRows op.
//
// Inputs (all device-resident):
//   input_indices -- sparse indices of the input tensor
//   input_values  -- sparse values (nnz entries)
//   input_shape   -- dense shape; the leading element is the row count
//   default_value -- value inserted into each empty row
// Outputs:
//   0: output_indices -- indices with one entry added per empty row
//   1: output_values  -- values with default_value in the inserted slots
//   2: empty_row_indicator (optional) -- per-row "was empty" flag
//   3: reverse_index_map  (optional)  -- output position of each input entry
void SparseFillEmptyRowsGpuImpl(OpKernelContext* context,
                                const int64* input_indices,
                                const T* input_values, const int64 nnz,
                                const int64* input_shape,
                                const T* default_value) {
  auto d = context->eigen_gpu_device();
  auto OpStream = d.stream();

  // The dense shape is stored on the GPU; bring the leading (row) dimension
  // to the host because every allocation below is sized by it. If the dense
  // shape were already on the CPU this copy could be skipped.
  int64 dense_row_number;
  cudaError_t cuda_status =
      cudaMemcpyAsync(&dense_row_number, input_shape, sizeof(int64),
                      cudaMemcpyDeviceToHost, OpStream);
  OP_REQUIRES(context, cuda_status == cudaSuccess,
              errors::Internal("SparseFillEmptyRows: copying dense shape to "
                               "host failed: ",
                               cudaGetErrorString(cuda_status)));
  cudaStreamSynchronize(OpStream);

  // Per-row scratch buffers:
  //   input_row_offset[r]  -- start offset of row r in the input entries
  //   output_row_offset[r] -- start offset of row r in the output entries
  //   row_nnz_count[r]     -- number of non-zero entries in row r
  Tensor input_row_offset;
  Tensor output_row_offset;
  Tensor row_nnz_count;

  // The offset buffers need dense_row_number + 1 slots: one extra place to
  // store the initial offset value 0 (and the total count at the end).
  OP_REQUIRES_OK(context, context->allocate_temp(
                              DT_INT64, TensorShape({dense_row_number + 1}),
                              &input_row_offset));

  OP_REQUIRES_OK(context, context->allocate_temp(
                              DT_INT64, TensorShape({dense_row_number + 1}),
                              &output_row_offset));

  OP_REQUIRES_OK(
      context, context->allocate_temp(
                   // DT_INT32 instead of DT_INT64, because the counting
                   // kernel's CUDA atomic_add only supports int32.
                   DT_INT32, TensorShape({dense_row_number}), &row_nnz_count));

  // Zero the counters and the first offset slot. Use the async variants on
  // OpStream so these are stream-ordered with the kernels below, instead of
  // relying on legacy default-stream synchronization (a race under
  // per-thread default streams).
  cudaMemsetAsync(row_nnz_count.flat<int>().data(), 0,
                  sizeof(int) * dense_row_number, OpStream);
  cudaMemsetAsync(input_row_offset.flat<int64>().data(), 0, sizeof(int64),
                  OpStream);
  cudaMemsetAsync(output_row_offset.flat<int64>().data(), 0, sizeof(int64),
                  OpStream);

  // Count the number of non-zero entries in each row (one thread per entry).
  GpuLaunchConfig count_kernel_config = GetGpuLaunchConfig(nnz, d);
  TF_CHECK_OK(GpuLaunchKernel(
      SparseFillEmptyRowCountKernel, count_kernel_config.block_count,
      count_kernel_config.thread_per_block, 0, d.stream(), input_indices, nnz,
      input_shape, row_nnz_count.flat<int>().data(),
      input_row_offset.flat<int64>().data(),
      output_row_offset.flat<int64>().data()));

  /* Calculate the offset of each row of input
   *  example: the number of entries in each row: [3, 4, 0, 0, 6]
   *  the offset of each row of input: [0, 3, 7, 7, 7, 13]
   */
  // Determine temporary device storage requirements for the inclusive prefix
  // sum (size-query call: null temp storage, no work is enqueued).
  size_t temp_storage_bytes = 0;
  cub::DeviceScan::InclusiveSum(
      NULL, temp_storage_bytes, row_nnz_count.flat<int>().data(),
      input_row_offset.flat<int64>().data() + 1, dense_row_number, OpStream);

  // Allocate temporary storage for the inclusive prefix sum.
  Tensor temp_storage;
  OP_REQUIRES_OK(
      context,
      context->allocate_temp(
          DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
          &temp_storage));
  void* d_temp_storage = temp_storage.flat<int8>().data();

  // Run the prefix sum on OpStream so it is ordered after the count kernel.
  cub::DeviceScan::InclusiveSum(
      d_temp_storage, temp_storage_bytes, row_nnz_count.flat<int>().data(),
      input_row_offset.flat<int64>().data() + 1, dense_row_number, OpStream);

  /* Bump every empty row's count to 1, since each empty row receives exactly
   * one default entry in the output.
   *  example: row_nnz_count before: [3, 4, 0, 0, 6]
   *  row_nnz_count after:           [3, 4, 1, 1, 6]
   */
  // One thread per dense row. (The previous code sized this config by nnz
  // and then never used it, launching with the count kernel's config.)
  GpuLaunchConfig add_kernel_config = GetGpuLaunchConfig(dense_row_number, d);
  TF_CHECK_OK(GpuLaunchKernel(
      SparseFillEmptyRowAddOneKernel, add_kernel_config.block_count,
      add_kernel_config.thread_per_block, 0, d.stream(), input_shape,
      row_nnz_count.flat<int>().data()));

  // Calculate the offset of each row of output. The temp storage is reused:
  // the problem size is identical, so the required bytes are the same.
  cub::DeviceScan::InclusiveSum(
      d_temp_storage, temp_storage_bytes, row_nnz_count.flat<int>().data(),
      output_row_offset.flat<int64>().data() + 1, dense_row_number, OpStream);

  // The final output offset is the total output size:
  // output_nnz = nnz + number_of_empty_rows. Copy it to the host.
  int64 output_nnz;
  cuda_status = cudaMemcpyAsync(
      &output_nnz, output_row_offset.flat<int64>().data() + dense_row_number,
      sizeof(int64), cudaMemcpyDeviceToHost, OpStream);
  OP_REQUIRES(context, cuda_status == cudaSuccess,
              errors::Internal("SparseFillEmptyRows: copying output size to "
                               "host failed: ",
                               cudaGetErrorString(cuda_status)));
  cudaStreamSynchronize(OpStream);

  // Allocate output tensors.
  // NOTE(review): output_indices is allocated rank-2 ({output_nnz, 2}), so
  // this path assumes 2-D sparse tensors -- confirm against the op
  // registration.
  Tensor* output_indices;
  Tensor* output_values;
  OP_REQUIRES_OK(context,
                 context->allocate_output(0, TensorShape({output_nnz, 2}),
                                          &output_indices));
  OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({output_nnz}),
                                                   &output_values));

  // Optional output 2: per-row flag telling whether the row was empty.
  bool* empty_row_indicator = nullptr;
  if (context->output_required(2)) {
    Tensor* empty_row_indicator_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(2, TensorShape({dense_row_number}),
                                            &empty_row_indicator_t));
    empty_row_indicator = empty_row_indicator_t->vec<bool>().data();
    // Assume rows are non-empty first; the fill kernel flags the empty ones.
    cudaMemsetAsync(empty_row_indicator, 0, sizeof(bool) * dense_row_number,
                    OpStream);
  }

  // Optional output 3: for each input entry, its position in the output.
  int64* reverse_index_map = nullptr;
  if (context->output_required(3)) {
    Tensor* reverse_index_map_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(3, TensorShape({nnz}),
                                                     &reverse_index_map_t));
    reverse_index_map = reverse_index_map_t->vec<int64>().data();
  }

  // Second pass: move existing entries to their output slots and insert
  // default_value entries into the empty rows (one thread per dense row).
  GpuLaunchConfig config = GetGpuLaunchConfig(dense_row_number, d);
  TF_CHECK_OK(GpuLaunchKernel(
      SparseFillEmptyRowFillKernel<T>, config.block_count,
      config.thread_per_block, 0, d.stream(), input_indices, input_values,
      input_shape, default_value, input_row_offset.flat<int64>().data(),
      output_row_offset.flat<int64>().data(),
      output_indices->flat<int64>().data(), output_values->flat<T>().data(),
      empty_row_indicator, reverse_index_map));
}