static llvm::Expected<DenseGpuTensor> ComputeConvGpuOp()

in backends/gpu/lib/ops/tf/dnn_ops.cc [257:493]
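ComputeConvGpuOp computes a TF convolution on GPU: the HWIO-layout filter is transposed into the layout cuDNN expects, asymmetric TF padding is applied up front with an explicit pad op, stride-1 unpadded 1x1 channels-last convolutions are lowered to a cuBLAS GEMM, and everything else runs through cuDNN convolution forward with a cached, autotuned algorithm choice.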


static llvm::Expected<DenseGpuTensor> ComputeConvGpuOp(
    GpuDispatchContext* dctx, const DenseGpuTensor& input,
    const DenseGpuTensor& filter, const OpAttrsRef& attrs,
    const TensorMetadata& result_md) {
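  // temp_buffer will hold the filter transposed out of TF's HWIO layout;
  // output_buffer receives the convolution (or GEMM) result.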
  TFRT_ASSIGN_OR_RETURN(auto temp_buffer,
                        AllocateBuffer(dctx, filter.dtype(), filter.shape()));

  TFRT_ASSIGN_OR_RETURN(auto output_buffer,
                        AllocateBuffer(dctx, result_md.dtype, result_md.shape));

  // TODO(b/153682815): Report error if attribute is absent.
  auto padding = attrs.GetStringAsserting("padding");
  auto explicit_paddings = attrs.GetArrayOptional<int>("explicit_paddings");
  auto data_format = attrs.GetStringOptional("data_format");
  auto strides = attrs.GetArrayOptional<Index>("strides");
  auto dilations = attrs.GetArrayOptional<Index>("dilations");

  auto rank = input.shape().GetRank();
  auto channel_order = GetTfChannelOrder(data_format);

  // Determine how to transpose from HWIO to the desired filter layout.
  llvm::SmallVector<Index, 4> transpose;
  transpose.reserve(rank);
  for (int i = 0; i < rank; ++i) transpose.push_back(i);
  switch (channel_order) {
    case ChannelOrder::ChannelFirst:  // HWIO -> OIHW, i.e. {3, 2, 0, 1}
      RotateRight(transpose, 2);
      std::swap(transpose[0], transpose[1]);
      break;
    case ChannelOrder::ChannelLast:  // HWIO -> OHWI, i.e. {3, 0, 1, 2}
      RotateRight(transpose, 1);
      break;
  }

  auto filter_dims_hwio = GetDimensions(filter.shape());
  llvm::SmallVector<Index, 4> filter_dims;  // OIHW or OHWI
  filter_dims.reserve(rank);
  for (auto i : transpose) filter_dims.push_back(filter_dims_hwio[i]);

  auto filter_dims_oihw = filter_dims;
  auto input_dims_nchw = GetDimensions(input.shape());
  auto output_dims_nchw = GetDimensions(result_md.shape);
  if (channel_order == ChannelOrder::ChannelLast) {
    // If layout is NHWC, convert to NCHW.
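    // e.g. {N, H, W, C} becomes {N, C, H, W}: rotating everything but the batch
    // dimension right by one moves the channel count in front of the spatial dims.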
    RotateRight(llvm::MutableArrayRef<Index>(input_dims_nchw).drop_front());
    RotateRight(llvm::MutableArrayRef<Index>(output_dims_nchw).drop_front());
    RotateRight(llvm::MutableArrayRef<Index>(filter_dims_oihw).drop_front());
  }
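  // Compute per-dimension strides, dilations, and before/after paddings according
  // to the TF padding attributes.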
  TFRT_ASSIGN_OR_RETURN(
      auto windowed_output_data,
      GetTfWindowedOutputData(input_dims_nchw, filter_dims_oihw, channel_order,
                              padding, explicit_paddings, strides, dilations));

  auto input_ptr = input.buffer().pointer();
  llvm::Optional<DenseGpuTensor> padded_input;
  auto paddings = windowed_output_data.paddings_before;
  // Pad the input manually if the before and after paddings are not the same.
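  // cuDNN only supports symmetric padding, so the asymmetric part is applied with
  // an explicit pad op (pads_manual is a [rank, 2] array of before/after pairs) and
  // only the common symmetric part is handed to cuDNN. E.g. pad_before=0,
  // pad_after=1 becomes a manual trailing pad of 1 plus a cuDNN padding of 0.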
  if (paddings != windowed_output_data.paddings_after) {
    llvm::SmallVector<int64_t, 8> pads_manual;
    pads_manual.reserve(2 * rank);
    // Zero paddings before spatial dimensions.
    pads_manual.resize(channel_order == ChannelOrder::ChannelLast ? 2 : 4, 0);
    for (int i = 0; i < rank - 2; ++i) {
      auto pad_before = windowed_output_data.paddings_before[i];
      auto pad_after = windowed_output_data.paddings_after[i];
      auto difference = pad_before - pad_after;
      pads_manual.push_back(std::max<Index>(0, +difference));
      pads_manual.push_back(std::max<Index>(0, -difference));
      paddings[i] = std::min(pad_before, pad_after);
      // Update input dimensions.
      input_dims_nchw[2 + i] += std::abs(difference);
    }
    // Zero paddings after spatial dimensions.
    pads_manual.resize(2 * rank, 0);

    DenseView pads_manual_view(GetDType<int64_t>(), {rank, 2},
                               pads_manual.data());
    TFRT_ASSIGN_OR_RETURN(
        auto output_metadata,
        TfPadOutputShape<int64_t>(input.metadata(),
                                  pads_manual_view.GetTensor<int64_t, 2>()));
    TFRT_ASSIGN_OR_RETURN(
        auto pad_output,
        CallGpuPadOp(dctx, input, pads_manual_view, output_metadata));
    input_ptr = pad_output.buffer().pointer();
    padded_input.emplace(std::move(pad_output));
  }

  // If the image is channels-last and the filter is 1x1, we can skip the filter
  // transpose and evaluate a GEMM instead of a convolution.
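  // A stride-1, unpadded 1x1 convolution over an NHWC tensor is just a
  // [N*H*W, C_in] x [C_in, C_out] matrix multiply, and the trailing [I, O] slice
  // of the HWIO filter already has the right layout for it.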
  auto all_equal_to = [](llvm::ArrayRef<Index> array, Index value) {
    return is_splat(array) && array.front() == value;
  };
  if (channel_order == ChannelOrder::ChannelLast &&
      input_dims_nchw[1] == filter_dims_oihw[1] &&  // No grouped convolutions.
      all_equal_to(llvm::makeArrayRef(filter_dims_oihw).drop_front(2), 1) &&
      all_equal_to(windowed_output_data.strides, 1) &&
      all_equal_to(paddings, 0)) {
    auto batch_count = input_dims_nchw[0];
    auto channel_count = input_dims_nchw[1];
    auto pixel_count =
        std::accumulate(input_dims_nchw.begin() + 2, input_dims_nchw.end(), 1,
                        std::multiplies<Index>());
    auto reshaped_input =
        (padded_input ? *padded_input : input)
            .WithShape(TensorShape({batch_count * pixel_count, channel_count}));
    auto reshaped_filter = filter.WithShape(
        TensorShape(llvm::makeArrayRef(filter_dims_hwio).take_back(2)));
    if (auto error = RunCublasGemm(dctx->current_context(), dctx->blas_handle(),
                                   /*transpose_a=*/false, /*transpose_b=*/false,
                                   reshaped_input.getValue(),
                                   reshaped_filter.getValue(), output_buffer)) {
      return std::move(error);
    }
    return DenseGpuTensor(
        result_md.shape, result_md.dtype,
        MakeAvailableAsyncValueRef<GpuBuffer>(std::move(output_buffer)));
  }

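  // Otherwise fall through to the cuDNN convolution path: describe the input,
  // filter, and output tensors for cuDNN.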
  TFRT_ASSIGN_OR_RETURN(
      auto input_data,
      GetTensorDescriptorData(input.dtype(), input_dims_nchw, channel_order));
  TFRT_ASSIGN_OR_RETURN(
      auto filter_data,
      GetTensorDescriptorData(filter.dtype(), filter_dims, channel_order));
  TFRT_ASSIGN_OR_RETURN(auto output_data,
                        GetTensorDescriptorData(
                            result_md.dtype, output_dims_nchw, channel_order));

  TFRT_ASSIGN_OR_RETURN(auto input_desc, CreateTensorDescriptor(input_data));
  TFRT_ASSIGN_OR_RETURN(auto filter_desc,
                        CreateFilterDescriptor(filter_data.dtype, channel_order,
                                               ToIntVec(filter_dims_oihw)));
  TFRT_ASSIGN_OR_RETURN(auto output_desc, CreateTensorDescriptor(output_data));

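  // cuDNN computes output = alpha * conv(input, filter) + beta * output; with
  // alpha = 1 and beta = 0 the output buffer is simply overwritten.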
  auto alpha = ScalingFactor(1.0, input_data.dtype);
  auto beta = ScalingFactor(0.0, output_data.dtype);
  auto platform = dctx->dnn_handle().platform();

  // TODO(iga): Make this function take channel_order instead of FORMAT_IOHW.
  if (auto error =
          TransformFilterTensor(dctx->current_context(), dctx->stream(),
                                channel_order, filter, temp_buffer))
    return std::move(error);

  auto conv_dtype = input_data.dtype;
  // Always use mixed precision for fp16.
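  // i.e. fp16 tensors with fp32 accumulation via a float compute type.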
  if (conv_dtype == CUDNN_DATA_HALF) conv_dtype = CUDNN_DATA_FLOAT;

  TFRT_ASSIGN_OR_RETURN(auto conv_desc,
                        wrapper::CudnnCreateConvolutionDescriptor());
  if (auto error = wrapper::CudnnSetConvolutionDescriptor(
          conv_desc.get(), ToIntVec(paddings),
          ToIntVec(windowed_output_data.strides),
          ToIntVec(windowed_output_data.dilations), CUDNN_CROSS_CORRELATION,
          conv_dtype))
    return std::move(error);

  // Opt in to using tensor cores. This might be overridden below.
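  // CUDNN_TENSOR_OP_MATH lets cuDNN pick Tensor Core kernels where available.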
  if (auto error = wrapper::CudnnSetConvolutionMathType(conv_desc.get(),
                                                        CUDNN_TENSOR_OP_MATH))
    return std::move(error);

  cudnnConvolutionFwdAlgo_t algo;
  size_t workspace_size_bytes = 0;
  GpuBuffer workspace_buffer;

  // TODO(tfrt-devs): Instead of reading default algorithms from an
  // environment variable, we need to pass these options explicitly through
  // op-specific interfaces.
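  // Pick a forward convolution algorithm: either a fixed default taken from the
  // environment, or one found by autotuning and cached per descriptor combination.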
  if (auto default_algo = DefaultCudnnCovolutionForwardAlgorithm()) {
    algo = *default_algo;
    TFRT_ASSIGN_OR_RETURN(
        workspace_size_bytes,
        wrapper::CudnnGetConvolutionForwardWorkspaceSize(
            dctx->dnn_handle(), input_desc.get(), filter_desc.get(),
            conv_desc.get(), output_desc.get(), algo));
  } else {
    auto& map = GetConvolutionForwardAlgorithmMap();
    auto key = std::make_tuple(input_data, filter_data, output_data);
    auto it = map.find(key);
    if (it == map.end()) {
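      // Reserve a scratch workspace for autotuning, trying 1 GiB, then 128 MiB,
      // then 16 MiB before falling back to searching without a workspace.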
      for (size_t mega_bytes : {1024, 128, 16, 0}) {
        workspace_size_bytes = mega_bytes * 1024 * 1024;
        if (workspace_size_bytes == 0) break;
        if (auto workspace_buffer_or_error = GpuBuffer::Allocate(
                dctx->allocator(), workspace_size_bytes, dctx->stream())) {
          workspace_buffer = std::move(*workspace_buffer_or_error);
          break;
        }
      }
      TFRT_ASSIGN_OR_RETURN(
          auto algo_perfs,
          wrapper::CudnnFindConvolutionForwardAlgorithm(
              dctx->current_context(), dctx->dnn_handle(), input_desc.get(),
              input_ptr, filter_desc.get(), temp_buffer.pointer(),
              conv_desc.get(), output_desc.get(), output_buffer.pointer(), 1,
              workspace_buffer.pointer(), workspace_size_bytes));
      const auto& algo_perf = algo_perfs.front();
      it = map.emplace_hint(it, key,
                            std::make_tuple(algo_perf.algo, algo_perf.memory,
                                            algo_perf.mathType));
    }
    algo = std::get<cudnnConvolutionFwdAlgo_t>(it->second);
    workspace_size_bytes = std::get<size_t>(it->second);
    if (auto error = wrapper::CudnnSetConvolutionMathType(
            conv_desc.get(), std::get<cudnnMathType_t>(it->second)))
      return std::move(error);
  }

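  // The chosen algorithm may need more workspace than was reserved above (or none
  // was reserved at all), so (re)allocate it lazily here.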
  TFRT_ASSIGN_OR_RETURN(
      auto workspace_ptr, [&]() -> llvm::Expected<wrapper::Pointer<void>> {
        if (workspace_size_bytes == 0) {
          return wrapper::Pointer<void>(nullptr, platform);
        }
        if (!workspace_buffer ||
            workspace_buffer.size() < workspace_size_bytes) {
          TFRT_ASSIGN_OR_RETURN(
              workspace_buffer,
              GpuBuffer::Allocate(dctx->allocator(), workspace_size_bytes,
                                  dctx->stream()));
        }
        return workspace_buffer.pointer();
      }());

  if (auto error = wrapper::CudnnConvolutionForward(
          dctx->current_context(), dctx->dnn_handle(), &alpha, input_desc.get(),
          input_ptr, filter_desc.get(), temp_buffer.pointer(), conv_desc.get(),
          algo, workspace_ptr, workspace_size_bytes, &beta, output_desc.get(),
          output_buffer.pointer())) {
    return std::move(error);
  }

  return DenseGpuTensor(
      result_md.shape, result_md.dtype,
      MakeAvailableAsyncValueRef<GpuBuffer>(std::move(output_buffer)));
}
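
For reference, below is a minimal standalone sketch (not part of the op) of the filter-permutation derivation above, i.e. HWIO -> OIHW for channels-first and HWIO -> OHWI for channels-last. It assumes RotateRight rotates its argument's elements to the right by the given count; the stand-in implementation here is hypothetical, not the helper used by dnn_ops.cc.

#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the RotateRight helper used by ComputeConvGpuOp.
static void RotateRight(std::vector<int>& values, size_t count) {
  std::rotate(values.rbegin(), values.rbegin() + count, values.rend());
}

int main() {
  std::vector<int> hwio_to_oihw = {0, 1, 2, 3};
  RotateRight(hwio_to_oihw, 2);                 // {2, 3, 0, 1}
  std::swap(hwio_to_oihw[0], hwio_to_oihw[1]);  // {3, 2, 0, 1}, i.e. HWIO -> OIHW

  std::vector<int> hwio_to_ohwi = {0, 1, 2, 3};
  RotateRight(hwio_to_ohwi, 1);                 // {3, 0, 1, 2}, i.e. HWIO -> OHWI

  for (int i : hwio_to_oihw) std::printf("%d ", i);
  std::printf("\n");
  for (int i : hwio_to_ohwi) std::printf("%d ", i);
  std::printf("\n");
  return 0;
}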