void Compute()

in tensorflow/tensorflow/core/kernels/conv_grad_ops_3d.cc [1584:1918]
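
Computes the gradient of a 3-D convolution with respect to its filter on the GPU. Two fast paths reduce the problem to a single cuBLAS GEMM (a 1x1x1 filter with unit strides, and a filter that covers the whole input under VALID padding); every other case pads the input so the remaining padding is symmetric, transposes NHWC tensors to the channels-first layout the DNN backend expects, optionally autotunes the cuDNN/MIOpen backward-filter algorithm, and finally transposes the resulting gradient back into TensorFlow's filter layout.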


  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();

    const Tensor& out_backprop = context->input(2);
    const TensorShape& out_backprop_shape = out_backprop.shape();

    TensorShape filter_shape;
    if (takes_shape_) {
      const Tensor& filter_sizes = context->input(1);
      OP_REQUIRES_OK(context, MakeShape(filter_sizes, &filter_shape));
    } else {
      filter_shape = context->input(1).shape();
    }

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(
        context,
        ConvBackpropComputeDimensionsV2(
            "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3, input_shape,
            filter_shape, out_backprop_shape, dilation_, stride_, padding_,
            /*explicit_paddings=*/{}, data_format_, &dims));

    Tensor* filter_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, filter_shape, &filter_backprop));

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    bool is_grouped_convolution = filter_shape.dim_size(3) != dims.in_depth;
    if (!is_grouped_convolution && dims.filter_size(1) == 1 &&
        dims.filter_size(2) == 1 && dims.filter_size(0) == 1 &&
        dims.dilation(2) == 1 && dims.dilation(1) == 1 &&
        dims.dilation(0) == 1 && dims.stride(2) == 1 && dims.stride(1) == 1 &&
        dims.stride(0) == 1 && data_format_ == FORMAT_NHWC) {
      const uint64 m = dims.in_depth;
      const uint64 k = dims.batch_size * dims.input_size(1) *
                       dims.input_size(2) * dims.input_size(0);
      const uint64 n = dims.out_depth;

      // The shape of output backprop is
      //   [batch, out_z, out_y, out_x, out_depth]
      // From cublas's perspective, it is: n x k
      auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());

      // The shape of the input is:
      //   [batch, in_z, in_y, in_x, in_depth].
      // From cublas's perspective, it is: m x k
      auto b_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                  input.template flat<T>().size());

      // The shape of the filter backprop is:
      //   [1, 1, 1, in_depth, out_depth]
      // From cublas's perspective, it is: n x m
      auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
                                  filter_backprop->template flat<T>().size());

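      // In effect this computes filter_backprop = input^T * out_backprop in a
      // single GEMM: with a 1x1x1 filter and unit strides and dilations, every
      // output position reads exactly one input position, so
      //   dW[i, o] = sum over positions p of input[p, i] * out_backprop[p, o].
      // cuBLAS is column-major, so the row-major k x n out_backprop buffer is
      // passed untransposed (seen as n x k) and the row-major k x m input
      // buffer is passed transposed, producing the n x m result that is the
      // row-major [in_depth, out_depth] filter gradient.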
      bool blas_launch_status =
          stream
              ->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
                             se::blas::Transpose::kTranspose, n, m, k, 1.0f,
                             a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
      return;
    } else if (!is_grouped_convolution &&
               dims.filter_size(0) == dims.input_size(0) &&
               dims.filter_size(1) == dims.input_size(1) &&
               dims.filter_size(2) == dims.input_size(2) &&
               padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
      const uint64 m = dims.input_size(0) * dims.input_size(1) *
                       dims.input_size(2) * dims.in_depth;
      const uint64 k = dims.batch_size;
      const uint64 n = dims.out_depth;

      auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                  input.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
                                  filter_backprop->template flat<T>().size());

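      // When the filter covers the entire (VALID-padded) input volume, each
      // batch element contributes exactly one output position, so the filter
      // gradient is again a single GEMM:
      //   dW[i, o] = sum over batches b of input[b, i] * out_backprop[b, o],
      // where i ranges over the m = in_z * in_y * in_x * in_depth flattened
      // filter elements and k is the batch size.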
      bool blas_launch_status =
          stream
              ->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
                             se::blas::Transpose::kTranspose, n, m, k, 1.0f,
                             b_ptr, n, a_ptr, m, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
      return;
    }

    int padding_planes = dims.SpatialPadding(padding_, 0);
    int padding_rows = dims.SpatialPadding(padding_, 1);
    int padding_cols = dims.SpatialPadding(padding_, 2);
    const bool planes_odd = (padding_planes % 2 != 0);
    const bool rows_odd = (padding_rows % 2 != 0);
    const bool cols_odd = (padding_cols % 2 != 0);

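    // cuDNN and MIOpen only support padding a spatial dimension by the same
    // amount on both ends. If the total padding for a dimension is odd, pad
    // the input by one extra element on the high side of that dimension so
    // that the padding passed to the convolution descriptor below (total / 2
    // per side) is symmetric.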
    Tensor compatible_input;
    if (rows_odd || cols_odd || planes_odd) {
      OP_REQUIRES_OK(context,
                     context->allocate_temp(
                         DataTypeToEnum<T>::value,
                         ShapeFromFormat(data_format_, dims.batch_size,
                                         {{dims.input_size(0) + planes_odd,
                                           dims.input_size(1) + rows_odd,
                                           dims.input_size(2) + cols_odd}},
                                         dims.in_depth),
                         &compatible_input));
      functor::PadInput<GPUDevice, T, int, 5>()(
          context->template eigen_device<GPUDevice>(),
          To32Bit(input.tensor<T, 5>()), {{0, 0, 0}},
          {{planes_odd, rows_odd, cols_odd}},
          To32Bit(compatible_input.tensor<T, 5>()), data_format_);
    } else {
      compatible_input = input;
    }

    CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
        << "Negative paddings: (" << padding_rows << ", " << padding_cols
        << ", " << padding_planes << ")";
    se::dnn::BatchDescriptor input_desc(3);
    input_desc.set_count(dims.batch_size)
        .set_spatial_dim(DimIndex::X,
                         GetTensorDim(compatible_input, data_format_, '2'))
        .set_spatial_dim(DimIndex::Y,
                         GetTensorDim(compatible_input, data_format_, '1'))
        .set_spatial_dim(DimIndex::Z,
                         GetTensorDim(compatible_input, data_format_, '0'))
        .set_feature_map_count(dims.in_depth)
        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
    se::dnn::BatchDescriptor output_desc(3);
    output_desc.set_count(dims.batch_size)
        .set_spatial_dim(DimIndex::X, dims.output_size(2))
        .set_spatial_dim(DimIndex::Y, dims.output_size(1))
        .set_spatial_dim(DimIndex::Z, dims.output_size(0))
        .set_feature_map_count(dims.out_depth)
        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
    se::dnn::FilterDescriptor filter_desc(3);
    filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
        .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
        .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
        .set_input_feature_map_count(filter_shape.dim_size(3))
        .set_output_feature_map_count(filter_shape.dim_size(4));
    se::dnn::ConvolutionDescriptor conv_desc(3);
    conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
        .set_dilation_rate(DimIndex::Y, dims.dilation(1))
        .set_dilation_rate(DimIndex::Z, dims.dilation(0))
        .set_filter_stride(DimIndex::X, dims.stride(2))
        .set_filter_stride(DimIndex::Y, dims.stride(1))
        .set_filter_stride(DimIndex::Z, dims.stride(0))
        .set_zero_padding(DimIndex::X, padding_cols / 2)
        .set_zero_padding(DimIndex::Y, padding_rows / 2)
        .set_zero_padding(DimIndex::Z, padding_planes / 2)
        .set_group_count(dims.in_depth / filter_shape.dim_size(3));
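    // The backend produces the filter gradient in OIDHW order, i.e.
    // [out_depth, in_depth, z, y, x], so stage it in a temporary here and
    // transpose it back to TensorFlow's [z, y, x, in_depth, out_depth] layout
    // at the end (see ReverseTransformFilter below).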
    Tensor pre_transformed_filter_backprop;
    OP_REQUIRES_OK(
        context, context->allocate_temp(
                     DataTypeToEnum<T>::value,
                     TensorShape({filter_shape.dim_size(4),
                                  filter_shape.dim_size(3), dims.filter_size(0),
                                  dims.filter_size(1), dims.filter_size(2)}),
                     &pre_transformed_filter_backprop));

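    // The convolution itself runs in channels-first layout, so NHWC tensors
    // must be transposed first. When the channel count is 1, NHWC and NCHW
    // have identical memory order, so a cheap reshape via CopyFrom replaces
    // the transpose kernel.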
    Tensor transformed_out_backprop;
    if (data_format_ == FORMAT_NHWC) {
      TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
                                dims.output_size(0), dims.output_size(1),
                                dims.output_size(2)};
      OP_REQUIRES_OK(
          context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
                                          &transformed_out_backprop));
      if (dims.out_depth > 1) {
        functor::NHWCToNCHW<GPUDevice, T, 5>()(
            context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
            transformed_out_backprop.tensor<T, 5>());
      } else {
        CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
      }
    } else {
      transformed_out_backprop = out_backprop;
    }
    Tensor transformed_input;
    if (data_format_ == FORMAT_NHWC) {
      TensorShape nchw_shape = {
          dims.batch_size, dims.in_depth, compatible_input.dim_size(1),
          compatible_input.dim_size(2), compatible_input.dim_size(3)};
      if (dims.in_depth > 1) {
        OP_REQUIRES_OK(context,
                       context->allocate_temp(DataTypeToEnum<T>::value,
                                              nchw_shape, &transformed_input));
        functor::NHWCToNCHW<GPUDevice, T, 5>()(
            context->eigen_device<GPUDevice>(),
            const_cast<const Tensor&>(compatible_input).tensor<T, 5>(),
            transformed_input.tensor<T, 5>());
      } else {
        CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape));
      }
    } else {
      transformed_input = compatible_input;
    }

    auto out_backprop_ptr =
        AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                       transformed_out_backprop.template flat<T>().size());
    auto filter_backprop_ptr = AsDeviceMemory(
        pre_transformed_filter_backprop.template flat<T>().data(),
        pre_transformed_filter_backprop.template flat<T>().size());
    auto input_ptr =
        AsDeviceMemory(transformed_input.template flat<T>().data(),
                       transformed_input.template flat<T>().size());

    static int64 ConvolveBackwardFilterScratchSize = GetDnnWorkspaceLimit(
        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default

    const int device_id = stream->parent()->device_ordinal();
    DataType dtype = input.dtype();
    const ConvParameters conv_parameters = {
        dims.batch_size,
        dims.in_depth,
        {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
        FORMAT_NCHW,
        dims.out_depth,
        {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
        {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
        {{dims.stride(0), dims.stride(1), dims.stride(2)}},
        {{padding_planes, padding_rows, padding_cols}},
        dtype,
        device_id,
        conv_desc.group_count()};

    using se::dnn::AlgorithmConfig;
    using se::dnn::AlgorithmDesc;
    using se::dnn::ProfileResult;
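    // If autotuning is enabled and there is no cached algorithm choice for
    // this convolution shape (keyed by conv_parameters above), profile the
    // available backward-filter algorithms and cache the fastest; otherwise
    // the cached AlgorithmConfig is reused directly.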
    AlgorithmConfig algorithm_config;
    if (cudnn_use_autotune_ && !AutoTuneConv3dBwdFilter::GetInstance()->Find(
                                   conv_parameters, &algorithm_config)) {
#if GOOGLE_CUDA
      std::vector<AlgorithmDesc> algorithms;
      CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
          conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
              stream->parent()),
          &algorithms));
      ProfileResult best_result;
      ProfileResult best_result_no_scratch;
      for (auto profile_algorithm : algorithms) {
        // TODO(zhengxq): profile each algorithm multiple times for better
        // accuracy.
        DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
                                              context);
        ProfileResult profile_result;
        bool cudnn_launch_status =
            stream
                ->ThenConvolveBackwardFilterWithAlgorithm(
                    input_desc, input_ptr, output_desc, out_backprop_ptr,
                    conv_desc, filter_desc, &filter_backprop_ptr,
                    &scratch_allocator, AlgorithmConfig(profile_algorithm),
                    &profile_result)
                .ok();
        if (cudnn_launch_status) {
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
            if (scratch_allocator.TotalByteSize() == 0 &&
                profile_result.elapsed_time_in_ms() <
                    best_result_no_scratch.elapsed_time_in_ms()) {
              best_result_no_scratch = profile_result;
            }
          }
        }
      }
      OP_REQUIRES(context,
                  best_result.is_valid() || best_result_no_scratch.is_valid(),
                  errors::NotFound("No algorithm worked!"));
      if (best_result.is_valid()) {
        algorithm_config.set_algorithm(best_result.algorithm());
      }
      if (best_result_no_scratch.is_valid()) {
        algorithm_config.set_algorithm_no_scratch(
            best_result_no_scratch.algorithm());
      }
#elif TENSORFLOW_USE_ROCM
      DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
                                            context);
      ProfileResult best_result;
      bool miopen_find_status =
          stream
              ->ThenConvolveBackwardFilterWithAlgorithm(
                  input_desc, input_ptr, output_desc, out_backprop_ptr,
                  conv_desc, filter_desc, &filter_backprop_ptr,
                  &scratch_allocator, AlgorithmConfig(), &best_result)
              .ok();
      OP_REQUIRES(
          context, miopen_find_status && best_result.is_valid(),
          errors::NotFound("Failed to find backward filter algorithm!"));
      algorithm_config.set_algorithm(best_result.algorithm());
      algorithm_config.set_scratch_size(best_result.scratch_size());
#endif
      AutoTuneConv3dBwdFilter::GetInstance()->Insert(conv_parameters,
                                                     algorithm_config);
    }
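    // Launch the backward-filter convolution with the selected (or cached)
    // algorithm configuration.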
    DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
                                          context);
    bool cudnn_launch_status =
        stream
            ->ThenConvolveBackwardFilterWithAlgorithm(
                input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
                filter_desc, &filter_backprop_ptr, &scratch_allocator,
                algorithm_config, nullptr)
            .ok();

    if (!cudnn_launch_status) {
      context->SetStatus(errors::Internal(
          "cuDNN Backward Filter function launch failure: input shape(",
          input_shape.DebugString(), ") filter shape(",
          filter_shape.DebugString(), ")"));
      return;
    }

    auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
    functor::ReverseTransformFilter<GPUDevice, T, 5>()(
        context->eigen_device<GPUDevice>(), /*src_filter_format=*/FORMAT_OIHW,
        toConstTensor(pre_transformed_filter_backprop).template tensor<T, 5>(),
        filter_backprop->tensor<T, 5>());
  }
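
To see why the first fast path is valid, here is a small self-contained CPU sketch (not part of the kernel; sizes and values are illustrative) showing that for a 1x1x1 filter with unit strides the conv3d filter gradient collapses to the single matrix product dW = X^T * dY that the kernel hands to cuBLAS, with m, k, and n named as in the kernel's GEMM call:

#include <cstdio>
#include <vector>

int main() {
  // Illustrative sizes: k = batch * in_z * in_y * in_x (all spatial
  // positions), m = in_depth, n = out_depth.
  const int k = 4, m = 3, n = 2;

  // Row-major X: [k, m] (input) and dY: [k, n] (output backprop), filled
  // with arbitrary values.
  std::vector<float> X(k * m), dY(k * n), dW(m * n, 0.0f);
  for (int i = 0; i < k * m; ++i) X[i] = 0.1f * i;
  for (int i = 0; i < k * n; ++i) dY[i] = 0.05f * i;

  // With a 1x1x1 filter every output position reads exactly one input
  // position, so dW[i, o] = sum_p X[p, i] * dY[p, o] -- one GEMM, with no
  // im2col-style expansion needed.
  for (int p = 0; p < k; ++p)
    for (int i = 0; i < m; ++i)
      for (int o = 0; o < n; ++o)
        dW[i * n + o] += X[p * m + i] * dY[p * n + o];

  for (int i = 0; i < m; ++i)
    for (int o = 0; o < n; ++o)
      std::printf("dW[%d,%d] = %f\n", i, o, dW[i * n + o]);
  return 0;
}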