void operator()

in tensorflow_addons/custom_ops/layers/cc/kernels/embedding_bag_backward_kernels.cu.cc [150:233]


  void operator()(const GPUDevice &d,
                  typename TTypes<Tindices, 2>::ConstTensor indices,
                  typename TTypes<T, 2>::ConstTensor params,
                  typename TTypes<T, 2>::ConstTensor weights,
                  typename TTypes<T, 2>::ConstTensor grads,
                  typename TTypes<T, 2>::Tensor params_grads,
                  typename TTypes<T, 2>::Tensor weights_grads,
                  Combiner combiner, OpKernelContext *context) {
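    // Backward pass for EmbeddingBag: weights_grads is computed with a
    // per-bag reduction kernel, and params_grads by sorting the indices and
    // accumulating gradient rows per embedding index.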
    // Allocator attributes for GPU-resident temporaries (pattern borrowed
    // from histogram_op_gpu.cu.cc).
    tensorflow::AllocatorAttributes gpu_allocator;
    gpu_allocator.set_on_host(false);
    gpu_allocator.set_gpu_compatible(true);

    Tensor sortedIndicesTensor;
    Tensor sortedIndicesCounterTensor;
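    // Scratch buffers, filled by PrepTempArraysKernel below: a sortable copy
    // of the flattened indices and, for each entry, its original flat
    // position (used to locate the matching gradient rows after sorting).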

    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<Tindices>::value,
                                          TensorShape({indices.size()}),
                                          &sortedIndicesTensor, gpu_allocator));
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<Tindices>::value,
                                TensorShape({indices.size()}),
                                &sortedIndicesCounterTensor, gpu_allocator));
    auto sortedIndices = sortedIndicesTensor.flat<Tindices>();
    auto sortedIndicesCounter = sortedIndicesCounterTensor.flat<Tindices>();
    // Note: splitting the two kernel launches onto separate streams barely
    // affected performance.
    const Eigen::Index batch_dim = indices.dimension(0);
    const Eigen::Index bag_dim = indices.dimension(1);
    const Eigen::Index output_dim = params.dimension(1);
    const auto params_size = params.size();
    const int kThreadsPerBlock = 32;
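    // Weights gradient: one thread block per (batch, bag) entry, with
    // kThreadsPerBlock threads reducing over the output_dim-wide grad and
    // params rows.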
    dim3 gridShape = dim3(batch_dim, bag_dim, 1);
    TF_CHECK_OK(GpuLaunchKernel(
        EmbeddingBagWeightsGradKernel<T, Tindices, kThreadsPerBlock>, gridShape,
        kThreadsPerBlock, 0, d.stream(), output_dim, indices.data(),
        params.data(), grads.data(), weights_grads.data(), combiner));

    const int indices_size = indices.size();
    const int total_blocks = Eigen::divup(indices_size, kThreadsPerBlock);
    gridShape = dim3(total_blocks, 1, 1);

    TF_CHECK_OK(GpuLaunchKernel(
        PrepTempArraysKernel<Tindices, kThreadsPerBlock>, gridShape,
        kThreadsPerBlock, 0, d.stream(), indices.data(), sortedIndices.data(),
        sortedIndicesCounter.data(), indices_size));

    thrust::device_ptr<Tindices> sortedIndicesCounterDevicePtr(
        sortedIndicesCounter.data());
    thrust::device_ptr<Tindices> sortedIndicesDevicePtr(sortedIndices.data());
    thrust::device_ptr<T> paramsGradDevicePtr(params_grads.data());
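    // Zero params_grads, then sort the copied indices (keys) together with
    // the position buffer (values) so that duplicate indices end up adjacent.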
    thrust::fill(paramsGradDevicePtr,
                 paramsGradDevicePtr + static_cast<int>(params_size),
                 static_cast<T>(0.0f));
    thrust::sort_by_key(sortedIndicesDevicePtr,
                        sortedIndicesDevicePtr + indices_size,
                        sortedIndicesCounterDevicePtr);
    // Handle each row with as few thread blocks as possible
    int threadsPerBlock;
    int blocksPerRow;
    if (output_dim <= MAX_THREADS_PER_BLOCK) {
      blocksPerRow = 1;
      threadsPerBlock = output_dim;
    } else {
      blocksPerRow =
          Eigen::divup(static_cast<int>(output_dim), MAX_THREADS_PER_BLOCK);
      threadsPerBlock =
          Eigen::divup(static_cast<int>(output_dim), blocksPerRow);
    }
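    // Params gradient: a row of blocksPerRow blocks per sorted index entry;
    // the sort above groups duplicate indices so their contributions can be
    // accumulated together by EmbeddingBagValuesGradKernel.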
    gridShape = dim3(indices_size, blocksPerRow, 1);
    TF_CHECK_OK(GpuLaunchKernel(
        EmbeddingBagValuesGradKernel<T, Tindices>, gridShape, threadsPerBlock,
        0, d.stream(), output_dim, bag_dim, sortedIndices.data(),
        sortedIndicesCounter.data(), params.data(), weights.data(),
        grads.data(), params_grads.data(), combiner));
  }
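
For reference, PrepTempArraysKernel is defined earlier in the same file and is
not shown in this excerpt. Judging from how its outputs are consumed above, it
presumably writes a sortable copy of each index along with that element's flat
position; the sketch below illustrates that assumed behavior (the name and body
are illustrative, not the file's actual implementation).

template <typename Tindices, const int kThreadsPerBlock>
__global__ void PrepTempArraysKernelSketch(
    const Tindices *__restrict__ indices, Tindices *__restrict__ sortedIndices,
    Tindices *__restrict__ sortedIndicesCounter, const int indices_size) {
  // One thread per element of the flattened indices tensor, matching the
  // divup(indices_size, kThreadsPerBlock) x kThreadsPerBlock launch above.
  const int idx = blockIdx.x * kThreadsPerBlock + threadIdx.x;
  if (idx < indices_size) {
    sortedIndices[idx] = indices[idx];  // sortable copy of the index values
    sortedIndicesCounter[idx] = idx;    // original flat position of the entry
  }
}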