Status operator()

in tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc [375:456]

Backward-pass functor for the correlation cost op on GPU: given topdiff (the
gradient of the op's output), it computes the gradients with respect to both
inputs, input_a and input_b.

  Status operator()(OpKernelContext *context, const Tensor &input_a_t,
                    const Tensor &input_b_t, const Tensor &topdiff_t,
                    Tensor *output_a_gradient_t, Tensor *output_b_gradient_t,
                    /* params */
                    int kernel_size, int max_displacement, int stride_1,
                    int stride_2, int pad, TensorFormat data_format) {
    // Do not change: the CUDA kernels expect THREADS_PER_BLOCK == 32
    // (one warp; the kernels are written assuming this block width).
    const int THREADS_PER_BLOCK = 32;

    const int32 N = GetTensorDim(input_a_t, data_format, 'N');
    const int32 iC = GetTensorDim(input_a_t, data_format, 'C');
    const int32 iH = GetTensorDim(input_a_t, data_format, 'H');
    const int32 iW = GetTensorDim(input_a_t, data_format, 'W');

    Tensor padded_a_t;
    Tensor padded_b_t;
    TensorShape padded_shape({N, iH + 2 * pad, iW + 2 * pad, iC});
    TF_RETURN_IF_ERROR(context->allocate_temp(DataTypeToEnum<Dtype>::value,
                                              padded_shape, &padded_a_t));
    TF_RETURN_IF_ERROR(context->allocate_temp(DataTypeToEnum<Dtype>::value,
                                              padded_shape, &padded_b_t));

    dim3 blocks_grid(N, iH, iW);
    dim3 threads_block(THREADS_PER_BLOCK);

    // topdiff (the incoming output gradient) is always NCHW, regardless of
    // data_format.
    const int32 oC = GetTensorDim(topdiff_t, FORMAT_NCHW, 'C');
    const int32 oH = GetTensorDim(topdiff_t, FORMAT_NCHW, 'H');
    const int32 oW = GetTensorDim(topdiff_t, FORMAT_NCHW, 'W');
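
    // Hedged aside: the forward op (not shown here) derives oC/oH/oW from the
    // parameters above, presumably following the FlowNet-style correlation
    // arithmetic; the names below are illustrative, not taken from this file:
    //   kernel_radius     = (kernel_size - 1) / 2
    //   border            = max_displacement + kernel_radius
    //   displacement_size = 2 * (max_displacement / stride_2) + 1
    //   oC == displacement_size * displacement_size
    //   oH == (iH + 2 * pad - 2 * border) / stride_1   (up to rounding)
    //   oW == (iW + 2 * pad - 2 * border) / stride_1   (up to rounding)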

    // Zero-fill the scratch and gradient buffers; the padded copies rely on
    // this for their zero borders (we zero-pad).
    cudaMemset(padded_a_t.flat<Dtype>().data(), 0,
               padded_a_t.NumElements() * sizeof(Dtype));
    cudaMemset(padded_b_t.flat<Dtype>().data(), 0,
               padded_b_t.NumElements() * sizeof(Dtype));
    cudaMemset(output_a_gradient_t->flat<Dtype>().data(), 0,
               output_a_gradient_t->NumElements() * sizeof(Dtype));
    cudaMemset(output_b_gradient_t->flat<Dtype>().data(), 0,
               output_b_gradient_t->NumElements() * sizeof(Dtype));
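
    // Hedged note: the cudaMemset calls above are issued on CUDA's legacy
    // default stream, while TensorFlow schedules device work on its own
    // compute stream. A stream-ordered alternative (assuming the Eigen
    // GPUDevice d, obtained further down, were available here) would be:
    //   cudaMemsetAsync(padded_a_t.flat<Dtype>().data(), 0,
    //                   padded_a_t.NumElements() * sizeof(Dtype), d.stream());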

    const bool is_NCHW = (data_format == FORMAT_NCHW);

    // Copy each input into its zero-padded scratch buffer. NCHW inputs are
    // transposed on the way in, so the backward kernels below always see a
    // padded NHWC layout.
    if (is_NCHW) {
      pad_and_transpose<THREADS_PER_BLOCK><<<blocks_grid, threads_block>>>(
          input_a_t.flat<Dtype>().data(), padded_a_t.flat<Dtype>().data(), iC,
          iH, iW, pad);
      pad_and_transpose<THREADS_PER_BLOCK><<<blocks_grid, threads_block>>>(
          input_b_t.flat<Dtype>().data(), padded_b_t.flat<Dtype>().data(), iC,
          iH, iW, pad);
    } else {
      pad_and_no_transpose<THREADS_PER_BLOCK><<<blocks_grid, threads_block>>>(
          input_a_t.flat<Dtype>().data(), padded_a_t.flat<Dtype>().data(), iC,
          iH, iW, pad);
      pad_and_no_transpose<THREADS_PER_BLOCK><<<blocks_grid, threads_block>>>(
          input_b_t.flat<Dtype>().data(), padded_b_t.flat<Dtype>().data(), iC,
          iH, iW, pad);
    }
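
    // Hedged sketch of what pad_and_transpose presumably does (the real
    // kernel is defined earlier in this file): one 32-thread block per
    // (n, h, w) position of the NCHW input, threads striding over channels
    // and writing into the zero-initialized padded NHWC buffer.
    //
    //   template <unsigned int THREADS_PER_BLOCK, typename Dtype>
    //   __global__ void pad_and_transpose(const Dtype *in, Dtype *out,
    //                                     int C, int H, int W, int P) {
    //     const int n = blockIdx.x, h = blockIdx.y, w = blockIdx.z;
    //     const int pH = H + 2 * P, pW = W + 2 * P;
    //     for (int c = threadIdx.x; c < C; c += THREADS_PER_BLOCK) {
    //       // NCHW source index -> padded NHWC destination index.
    //       out[((n * pH + (h + P)) * pW + (w + P)) * C + c] =
    //           in[((n * C + c) * H + h) * W + w];
    //     }
    //   }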

    // Note: d is not used by the launches below; without an explicit stream
    // argument they run on CUDA's default stream rather than on TensorFlow's
    // compute stream.
    const GPUDevice &d = context->eigen_gpu_device();

    dim3 threads_per_block(THREADS_PER_BLOCK);
    dim3 total_blocks_corr(iH, iW, iC);

    for (int n = 0; n < N; ++n) {
      Correlation_backward_input1<THREADS_PER_BLOCK>
          <<<total_blocks_corr, threads_per_block>>>(
              n, output_a_gradient_t->flat<Dtype>().data(), iC, iH, iW,
              topdiff_t.flat<Dtype>().data(), oC, oH, oW,
              padded_b_t.flat<Dtype>().data(), pad, kernel_size,
              max_displacement, stride_1, stride_2, is_NCHW);
    }

    for (int n = 0; n < N; ++n) {
      Correlation_backward_input2<THREADS_PER_BLOCK>
          <<<total_blocks_corr, threads_per_block>>>(
              n, output_b_gradient_t->flat<Dtype>().data(), iC, iH, iW,
              topdiff_t.flat<Dtype>().data(), oC, oH, oW,
              padded_a_t.flat<Dtype>().data(), pad, kernel_size,
              max_displacement, stride_1, stride_2, is_NCHW);
    }
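
    // Hedged note: CUDA grids have at most three dimensions, all of which are
    // already spent on (iH, iW, iC) here; that is presumably why the batch
    // dimension is handled by a host-side loop with one launch per sample.
    // As with the pad kernels, no stream argument is passed, so these
    // launches also target the default stream rather than d.stream().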

    return Status::OK();
  }
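
For orientation, a hedged sketch of how such a functor might be invoked from a
grad op's Compute() method. The functor and variable names below are
illustrative assumptions, not taken from this file:

  // Hypothetical call site (names are assumptions for illustration):
  functor::CorrelationCostGradFunctor<GPUDevice, float> grad_functor;
  OP_REQUIRES_OK(context,
                 grad_functor(context, input_a_t, input_b_t, topdiff_t,
                              output_a_gradient_t, output_b_gradient_t,
                              /* params */ kernel_size, max_displacement,
                              stride_1, stride_2, pad, data_format));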