void Compute()

in tensorflow/tensorflow/core/kernels/conv_grad_ops_3d.cc [1145:1515]


  // Computes the gradient of a 3-D convolution with respect to its input on
  // the GPU.
  //
  // Inputs:
  //   0: when `takes_shape_` is set, a tensor describing the input shape
  //      (parsed with MakeShape); otherwise the original input tensor, of
  //      which only the shape is used.
  //   1: the convolution filter.
  //   2: `out_backprop`, the gradient w.r.t. the convolution output.
  // Output:
  //   0: `in_backprop`, allocated with the input shape.
  //
  // Two special cases (see below) are lowered to a single BLAS GEMM. All other
  // cases run a cuDNN/MIOpen backward-data convolution in NCDHW layout, with
  // the algorithm optionally autotuned and cached per ConvParameters.
  void Compute(OpKernelContext* context) override {
    const Tensor& filter = context->input(1);
    const TensorShape& filter_shape = filter.shape();

    const Tensor& out_backprop = context->input(2);
    const TensorShape& out_backprop_shape = out_backprop.shape();

    TensorShape input_shape;
    if (takes_shape_) {
      const Tensor& input_sizes = context->input(0);
      OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape));
    } else {
      input_shape = context->input(0).shape();
    }

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(context, ConvBackpropComputeDimensionsV2(
                                "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
                                input_shape, filter_shape, out_backprop_shape,
                                dilation_, stride_, padding_,
                                /*explicit_paddings=*/{}, data_format_, &dims));

    Tensor* in_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    // Writes in_backprop directly with a single GEMM. In row-major terms
    // out_backprop is viewed as [m x k], the filter as [n x k] and in_backprop
    // as [m x n], so in_backprop = out_backprop * filter^T. BLAS is
    // column-major, which is why the operands are passed swapped and the
    // filter operand is transposed.
    auto launch_gemm = [&](const uint64 m, const uint64 n, const uint64 k) {
      auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                  filter.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
                                  in_backprop->template flat<T>().size());

      auto transpose = se::blas::Transpose::kTranspose;
      auto no_transpose = se::blas::Transpose::kNoTranspose;

      bool blas_launch_status =
          stream
              ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
                             a_ptr, k, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
    };

    // The filter's input-channel dimension differs from the input depth only
    // for grouped convolutions; the GEMM fast paths do not handle those.
    const bool is_grouped_convolution =
        filter_shape.dim_size(3) != dims.in_depth;
    if (!is_grouped_convolution && dims.filter_size(0) == 1 &&
        dims.filter_size(1) == 1 && dims.filter_size(2) == 1 &&
        dims.dilation(0) == 1 && dims.dilation(1) == 1 &&
        dims.dilation(2) == 1 && dims.stride(0) == 1 && dims.stride(1) == 1 &&
        dims.stride(2) == 1 && data_format_ == FORMAT_NHWC) {
      // 1x1x1 filter, unit stride and dilation, channels-last layout: each
      // (batch, plane, row, col) position is one row of the GEMM.
      const uint64 m = dims.batch_size * dims.input_size(0) *
                       dims.input_size(1) * dims.input_size(2);
      const uint64 k = dims.out_depth;
      const uint64 n = dims.in_depth;
      launch_gemm(m, n, k);
      return;
    } else if (!is_grouped_convolution &&
               dims.filter_size(0) == dims.input_size(0) &&
               dims.filter_size(1) == dims.input_size(1) &&
               dims.filter_size(2) == dims.input_size(2) &&
               padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
      // The filter spans the whole (unpadded) input: each batch element is
      // one row of the GEMM and the whole input volume is one column block.
      const uint64 m = dims.batch_size;
      const uint64 k = dims.out_depth;
      const uint64 n = dims.input_size(0) * dims.input_size(1) *
                       dims.input_size(2) * dims.in_depth;
      launch_gemm(m, n, k);
      return;
    }

    int padding_planes = dims.SpatialPadding(padding_, 0);
    int padding_rows = dims.SpatialPadding(padding_, 1);
    int padding_cols = dims.SpatialPadding(padding_, 2);
    CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
        << "Negative paddings: (" << padding_rows << ", " << padding_cols
        << ", " << padding_planes << ")";
    const bool planes_odd = (padding_planes % 2 != 0);
    const bool rows_odd = (padding_rows % 2 != 0);
    const bool cols_odd = (padding_cols % 2 != 0);

    // cuDNN only supports the same amount of padding on both sides. When the
    // total padding of a dimension is odd, compute the gradient for an input
    // one element larger in that dimension and crop the extra trailing slice
    // afterwards (see the PadInput call below). The `+ *_odd` terms are 0 when
    // the padding is even, so this is also the plain NCDHW input shape.
    TensorShape compatible_input_shape = {
        dims.batch_size,
        dims.in_depth,
        dims.input_size(0) + planes_odd,
        dims.input_size(1) + rows_odd,
        dims.input_size(2) + cols_odd,
    };

    se::dnn::BatchDescriptor input_desc(3);
    input_desc.set_count(dims.batch_size)
        .set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
        .set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
        .set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
        .set_feature_map_count(dims.in_depth)
        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
    se::dnn::BatchDescriptor output_desc(3);
    output_desc.set_count(dims.batch_size)
        .set_spatial_dim(DimIndex::X, dims.output_size(2))
        .set_spatial_dim(DimIndex::Y, dims.output_size(1))
        .set_spatial_dim(DimIndex::Z, dims.output_size(0))
        .set_feature_map_count(dims.out_depth)
        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
    se::dnn::FilterDescriptor filter_desc(3);
    filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
        .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
        .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
        .set_input_feature_map_count(filter_shape.dim_size(3))
        .set_output_feature_map_count(filter_shape.dim_size(4));
    se::dnn::ConvolutionDescriptor conv_desc(3);
    conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
        .set_dilation_rate(DimIndex::Y, dims.dilation(1))
        .set_dilation_rate(DimIndex::Z, dims.dilation(0))
        .set_filter_stride(DimIndex::X, dims.stride(2))
        .set_filter_stride(DimIndex::Y, dims.stride(1))
        .set_filter_stride(DimIndex::Z, dims.stride(0))
        .set_zero_padding(DimIndex::X, padding_cols / 2)
        .set_zero_padding(DimIndex::Y, padding_rows / 2)
        .set_zero_padding(DimIndex::Z, padding_planes / 2)
        .set_group_count(dims.in_depth / filter_shape.dim_size(3));

    // Shape: out, in, z, y, x.
    Tensor transformed_filter;
    OP_REQUIRES_OK(
        context, context->allocate_temp(
                     DataTypeToEnum<T>::value,
                     TensorShape({filter_shape.dim_size(4),
                                  filter_shape.dim_size(3), dims.filter_size(0),
                                  dims.filter_size(1), dims.filter_size(2)}),
                     &transformed_filter));
    functor::TransformFilter<GPUDevice, T, int, 5>()(
        context->eigen_device<GPUDevice>(), FORMAT_OIHW,
        To32Bit(filter.tensor<T, 5>()),
        To32Bit(transformed_filter.tensor<T, 5>()));

    // Shape: batch, filters, z, y, x.
    Tensor transformed_out_backprop;
    if (data_format_ == FORMAT_NHWC) {
      TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
                                dims.output_size(0), dims.output_size(1),
                                dims.output_size(2)};
      if (dims.out_depth > 1) {
        OP_REQUIRES_OK(context, context->allocate_temp(
                                    DataTypeToEnum<T>::value, nchw_shape,
                                    &transformed_out_backprop));
        functor::NHWCToNCHW<GPUDevice, T, 5>()(
            context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
            transformed_out_backprop.tensor<T, 5>());
      } else {
        // With a single channel the channels-last and channels-first layouts
        // coincide, so the buffer is reused under the NCDHW shape.
        CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
      }
    } else {
      transformed_out_backprop = out_backprop;
    }
    // Shape: batch, filters, z, y, x.
    Tensor pre_transformed_in_backprop;
    OP_REQUIRES_OK(
        context,
        context->allocate_temp(DataTypeToEnum<T>::value, compatible_input_shape,
                               &pre_transformed_in_backprop));

    auto out_backprop_ptr =
        AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                       transformed_out_backprop.template flat<T>().size());
    auto filter_ptr =
        AsDeviceMemory(transformed_filter.template flat<T>().data(),
                       transformed_filter.template flat<T>().size());
    auto in_backprop_ptr =
        AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
                       pre_transformed_in_backprop.template flat<T>().size());

    static int64 ConvolveBackwardDataScratchSize = GetDnnWorkspaceLimit(
        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default

    const int device_id = stream->parent()->device_ordinal();
    // The autotune cache key must carry the element type T. Input 0 cannot be
    // used for this: when `takes_shape_` is set it is the shape tensor, whose
    // dtype is unrelated to T, which would let different T instantiations
    // share (and misuse) the same cached algorithm. out_backprop is always T.
    DataType dtype = out_backprop.dtype();
    const ConvParameters conv_parameters = {
        dims.batch_size,
        dims.in_depth,
        {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
        FORMAT_NCHW,
        dims.out_depth,
        {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
        {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
        {{dims.stride(0), dims.stride(1), dims.stride(2)}},
        {{padding_planes, padding_rows, padding_cols}},
        dtype,
        device_id,
        conv_desc.group_count()};

    using se::dnn::AlgorithmConfig;
    using se::dnn::AlgorithmDesc;
    using se::dnn::ProfileResult;
    AlgorithmConfig algorithm_config;
    // Autotune only on a cache miss; the chosen config is cached below.
    if (cudnn_use_autotune_ && !AutoTuneConv3dBwdData::GetInstance()->Find(
                                   conv_parameters, &algorithm_config)) {
#if GOOGLE_CUDA
      se::TfAllocatorAdapter tf_allocator_adapter(
          context->device()->GetAllocator({}), stream);
      se::cuda::RedzoneAllocator rz_allocator(
          stream, &tf_allocator_adapter, se::cuda::PtxCompilationOptions());
      se::DeviceMemory<T> in_backprop_ptr_rz(
          WrapRedzoneBestEffort(&rz_allocator, in_backprop_ptr));
      std::vector<AlgorithmDesc> algorithms;
      CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
          conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
              stream->parent()),
          &algorithms));
      ProfileResult best_result;
      ProfileResult best_result_no_scratch;
      std::vector<tensorflow::AutotuneResult> results;
      for (auto profile_algorithm : algorithms) {
        // TODO(zhengxq): profile each algorithm multiple times to better
        // accuracy.
        DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
                                              context);
        se::cuda::RedzoneAllocator rz_scratch_allocator(
            stream, &tf_allocator_adapter, se::cuda::PtxCompilationOptions(),
            /*memory_limit=*/ConvolveBackwardDataScratchSize);
        se::ScratchAllocator* allocator_used =
            !RedzoneCheckDisabled()
                ? static_cast<se::ScratchAllocator*>(&rz_scratch_allocator)
                : static_cast<se::ScratchAllocator*>(&scratch_allocator);
        ProfileResult profile_result;
        bool cudnn_launch_status =
            stream
                ->ThenConvolveBackwardDataWithAlgorithm(
                    filter_desc, filter_ptr, output_desc, out_backprop_ptr,
                    conv_desc, input_desc, &in_backprop_ptr_rz, allocator_used,
                    AlgorithmConfig(profile_algorithm), &profile_result)
                .ok();
        if (cudnn_launch_status) {
          if (profile_result.is_valid()) {
            // Scratch usage must be read from the allocator that was actually
            // passed to the launch; the other one never allocates and would
            // always report zero bytes.
            const int64 scratch_bytes =
                !RedzoneCheckDisabled()
                    ? rz_scratch_allocator
                          .TotalAllocatedBytesExcludingRedzones()
                    : scratch_allocator.TotalByteSize();

            results.emplace_back();
            auto& result = results.back();
            result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
            result.mutable_conv()->set_tensor_ops_enabled(
                profile_algorithm.tensor_ops_enabled());
            result.set_scratch_bytes(scratch_bytes);
            *result.mutable_run_time() = proto_utils::ToDurationProto(
                absl::Milliseconds(profile_result.elapsed_time_in_ms()));

            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
            if (scratch_bytes == 0 &&
                profile_result.elapsed_time_in_ms() <
                    best_result_no_scratch.elapsed_time_in_ms()) {
              best_result_no_scratch = profile_result;
            }
            // TODO(george): they don't do results at all??
            CheckRedzones(rz_scratch_allocator, &result);
            CheckRedzones(rz_allocator, &result);
          }
        }
      }
      LogConvAutotuneResults(se::dnn::ConvolutionKind::BACKWARD_DATA,
                             se::dnn::ToDataType<T>::value, in_backprop_ptr,
                             filter_ptr, out_backprop_ptr, input_desc,
                             filter_desc, output_desc, conv_desc,
                             stream->parent(), results);
      OP_REQUIRES(context,
                  best_result.is_valid() || best_result_no_scratch.is_valid(),
                  errors::NotFound("No algorithm worked!"));
      if (best_result.is_valid()) {
        algorithm_config.set_algorithm(best_result.algorithm());
      }
      if (best_result_no_scratch.is_valid()) {
        algorithm_config.set_algorithm_no_scratch(
            best_result_no_scratch.algorithm());
      }
#elif TENSORFLOW_USE_ROCM
      DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
                                            context);
      ProfileResult best_result;
      bool miopen_find_status =
          stream
              ->ThenConvolveBackwardDataWithAlgorithm(
                  filter_desc, filter_ptr, output_desc, out_backprop_ptr,
                  conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
                  AlgorithmConfig(), &best_result)
              .ok();
      OP_REQUIRES(context, miopen_find_status && best_result.is_valid(),
                  errors::NotFound("Failed to find backward data algorithm!"));
      algorithm_config.set_algorithm(best_result.algorithm());
      algorithm_config.set_scratch_size(best_result.scratch_size());
#endif
      AutoTuneConv3dBwdData::GetInstance()->Insert(conv_parameters,
                                                   algorithm_config);
    }
    DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
                                          context);
    bool cudnn_launch_status =
        stream
            ->ThenConvolveBackwardDataWithAlgorithm(
                filter_desc, filter_ptr, output_desc, out_backprop_ptr,
                conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
                algorithm_config, nullptr)
            .ok();

    // Stop here on failure: the buffer was not written, so cropping and
    // transposing it into the output would only do wasted work.
    OP_REQUIRES(
        context, cudnn_launch_status,
        errors::Internal(
            "cuDNN Backward Data function launch failure : input shape(",
            input_shape.DebugString(), ") filter shape(",
            filter_shape.DebugString(), ")"));

    if (rows_odd || cols_odd || planes_odd) {
      Tensor in_backprop_remove_padding;
      OP_REQUIRES_OK(context,
                     context->allocate_temp(
                         DataTypeToEnum<T>::value,
                         {dims.batch_size, dims.in_depth, dims.input_size(0),
                          dims.input_size(1), dims.input_size(2)},
                         &in_backprop_remove_padding));

      // Remove the padding for odd spatial dimensions: a negative "after"
      // padding crops the extra trailing element added above.
      functor::PadInput<GPUDevice, T, int, 5>()(
          context->eigen_device<GPUDevice>(),
          To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
                      .tensor<T, 5>()),
          {{0, 0, 0}}, {{-planes_odd, -rows_odd, -cols_odd}},
          To32Bit(in_backprop_remove_padding.tensor<T, 5>()), FORMAT_NCHW);

      pre_transformed_in_backprop = in_backprop_remove_padding;
    }

    if (data_format_ == FORMAT_NHWC) {
      functor::NCHWToNHWC<GPUDevice, T, 5>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(pre_transformed_in_backprop)
              .template tensor<T, 5>(),
          in_backprop->tensor<T, 5>());
    } else {
      *in_backprop = pre_transformed_in_backprop;
    }
  }