in tensorflow/tensorflow/core/kernels/conv_grad_ops_3d.cc [1145:1515]
void Compute(OpKernelContext* context) override {
const Tensor& filter = context->input(1);
const TensorShape& filter_shape = filter.shape();
const Tensor& out_backprop = context->input(2);
const TensorShape& out_backprop_shape = out_backprop.shape();
TensorShape input_shape;
if (takes_shape_) {
const Tensor& input_sizes = context->input(0);
OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape));
} else {
input_shape = context->input(0).shape();
}
ConvBackpropDimensions dims;
OP_REQUIRES_OK(context, ConvBackpropComputeDimensionsV2(
"Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
input_shape, filter_shape, out_backprop_shape,
dilation_, stride_, padding_,
/*explicit_paddings=*/{}, data_format_, &dims));
Tensor* in_backprop;
OP_REQUIRES_OK(context,
context->allocate_output(0, input_shape, &in_backprop));
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
bool is_grouped_convolution = filter_shape.dim_size(3) != dims.in_depth;
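// Fast path 1: a non-grouped 1x1x1 convolution with unit strides and
// dilations in NHWC applies the same dense [in_depth, out_depth] weight at
// every voxel, so the input gradient collapses to one GEMM:
//   in_backprop[m, n] = out_backprop[m, k] * filter[n, k]^T
// with m = batch * planes * rows * cols, k = out_depth, n = in_depth.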
if (!is_grouped_convolution && dims.filter_size(0) == 1 &&
dims.filter_size(1) == 1 && dims.filter_size(2) == 1 &&
dims.dilation(0) == 1 && dims.dilation(1) == 1 &&
dims.dilation(2) == 1 && dims.stride(0) == 1 && dims.stride(1) == 1 &&
dims.stride(2) == 1 && data_format_ == FORMAT_NHWC) {
const uint64 m = dims.batch_size * dims.input_size(0) *
dims.input_size(1) * dims.input_size(2);
const uint64 k = dims.out_depth;
const uint64 n = dims.in_depth;
auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
out_backprop.template flat<T>().size());
auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
filter.template flat<T>().size());
auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
in_backprop->template flat<T>().size());
auto transpose = se::blas::Transpose::kTranspose;
auto no_transpose = se::blas::Transpose::kNoTranspose;
bool blas_launch_status =
stream
->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
a_ptr, k, 0.0f, &c_ptr, n)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
", n=", n, ", k=", k));
}
return;
} else if (!is_grouped_convolution &&
dims.filter_size(0) == dims.input_size(0) &&
dims.filter_size(1) == dims.input_size(1) &&
dims.filter_size(2) == dims.input_size(2) &&
padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
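// Fast path 2: with VALID padding and a filter that spans the whole input
// volume, every batch element maps to a single output position, so the
// gradient is again one GEMM over the flattened
// planes * rows * cols * in_depth dimension.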
const uint64 m = dims.batch_size;
const uint64 k = dims.out_depth;
const uint64 n = dims.input_size(0) * dims.input_size(1) *
dims.input_size(2) * dims.in_depth;
auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
out_backprop.template flat<T>().size());
auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
filter.template flat<T>().size());
auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
in_backprop->template flat<T>().size());
auto transpose = se::blas::Transpose::kTranspose;
auto no_transpose = se::blas::Transpose::kNoTranspose;
bool blas_launch_status =
stream
->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
a_ptr, k, 0.0f, &c_ptr, n)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
", n=", n, ", k=", k));
}
return;
}
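// General path: run the convolution through cuDNN / MIOpen. First compute
// the total padding per spatial dimension so that odd amounts can be
// compensated for below.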
int padding_planes = dims.SpatialPadding(padding_, 0);
int padding_rows = dims.SpatialPadding(padding_, 1);
int padding_cols = dims.SpatialPadding(padding_, 2);
const bool planes_odd = (padding_planes % 2 != 0);
const bool rows_odd = (padding_rows % 2 != 0);
const bool cols_odd = (padding_cols % 2 != 0);
TensorShape compatible_input_shape;
if (rows_odd || cols_odd || planes_odd) {
// cuDNN only supports the same amount of padding on both sides.
compatible_input_shape = {
dims.batch_size,
dims.in_depth,
dims.input_size(0) + planes_odd,
dims.input_size(1) + rows_odd,
dims.input_size(2) + cols_odd,
};
} else {
compatible_input_shape = {dims.batch_size, dims.in_depth,
dims.input_size(0), dims.input_size(1),
dims.input_size(2)};
}
CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
<< "Negative paddings: (" << padding_rows << ", " << padding_cols
<< ", " << padding_planes << ")";
se::dnn::BatchDescriptor input_desc(3);
input_desc.set_count(dims.batch_size)
.set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
.set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
.set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
.set_feature_map_count(dims.in_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::BatchDescriptor output_desc(3);
output_desc.set_count(dims.batch_size)
.set_spatial_dim(DimIndex::X, dims.output_size(2))
.set_spatial_dim(DimIndex::Y, dims.output_size(1))
.set_spatial_dim(DimIndex::Z, dims.output_size(0))
.set_feature_map_count(dims.out_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::FilterDescriptor filter_desc(3);
filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
.set_spatial_dim(DimIndex::Y, dims.filter_size(1))
.set_spatial_dim(DimIndex::Z, dims.filter_size(0))
.set_input_feature_map_count(filter_shape.dim_size(3))
.set_output_feature_map_count(filter_shape.dim_size(4));
se::dnn::ConvolutionDescriptor conv_desc(3);
conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
.set_dilation_rate(DimIndex::Y, dims.dilation(1))
.set_dilation_rate(DimIndex::Z, dims.dilation(0))
.set_filter_stride(DimIndex::X, dims.stride(2))
.set_filter_stride(DimIndex::Y, dims.stride(1))
.set_filter_stride(DimIndex::Z, dims.stride(0))
.set_zero_padding(DimIndex::X, padding_cols / 2)
.set_zero_padding(DimIndex::Y, padding_rows / 2)
.set_zero_padding(DimIndex::Z, padding_planes / 2)
.set_group_count(dims.in_depth / filter_shape.dim_size(3));
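// cuDNN expects the filter in OIDHW order, so transpose TF's
// [z, y, x, in, out] filter into a temporary tensor.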
// Shape: out, in, z, y, x.
Tensor transformed_filter;
OP_REQUIRES_OK(
context, context->allocate_temp(
DataTypeToEnum<T>::value,
TensorShape({filter_shape.dim_size(4),
filter_shape.dim_size(3), dims.filter_size(0),
dims.filter_size(1), dims.filter_size(2)}),
&transformed_filter));
functor::TransformFilter<GPUDevice, T, int, 5>()(
context->eigen_device<GPUDevice>(), FORMAT_OIHW,
To32Bit(filter.tensor<T, 5>()),
To32Bit(transformed_filter.tensor<T, 5>()));
// Shape: batch, filters, z, y, x.
Tensor transformed_out_backprop;
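// cuDNN also expects NCDHW activations, so transpose out_backprop when the
// op runs in NHWC. With out_depth == 1 the two layouts are bit-identical
// and a reshape (CopyFrom) suffices.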
if (data_format_ == FORMAT_NHWC) {
TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
dims.output_size(0), dims.output_size(1),
dims.output_size(2)};
if (dims.out_depth > 1) {
OP_REQUIRES_OK(context, context->allocate_temp(
DataTypeToEnum<T>::value, nchw_shape,
&transformed_out_backprop));
functor::NHWCToNCHW<GPUDevice, T, 5>()(
context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
transformed_out_backprop.tensor<T, 5>());
} else {
CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
}
} else {
transformed_out_backprop = out_backprop;
}
// Shape: batch, filters, z, y, x.
Tensor pre_transformed_in_backprop;
OP_REQUIRES_OK(
context,
context->allocate_temp(DataTypeToEnum<T>::value, compatible_input_shape,
&pre_transformed_in_backprop));
auto out_backprop_ptr =
AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
transformed_out_backprop.template flat<T>().size());
auto filter_ptr =
AsDeviceMemory(transformed_filter.template flat<T>().data(),
transformed_filter.template flat<T>().size());
auto in_backprop_ptr =
AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
pre_transformed_in_backprop.template flat<T>().size());
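// Cap the scratch space handed to cuDNN; the limit can be overridden via
// the TF_CUDNN_WORKSPACE_LIMIT_IN_MB environment variable.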
static int64 ConvolveBackwardDataScratchSize = GetDnnWorkspaceLimit(
"TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32); // 4GB by default
const int device_id = stream->parent()->device_ordinal();
DataType dtype = context->input(0).dtype();
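// These parameters uniquely identify the convolution and serve as the key
// into the autotuning cache below.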
const ConvParameters conv_parameters = {
dims.batch_size,
dims.in_depth,
{{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
FORMAT_NCHW,
dims.out_depth,
{{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
{{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
{{dims.stride(0), dims.stride(1), dims.stride(2)}},
{{padding_planes, padding_rows, padding_cols}},
dtype,
device_id,
conv_desc.group_count()};
using se::dnn::AlgorithmConfig;
using se::dnn::AlgorithmDesc;
using se::dnn::ProfileResult;
AlgorithmConfig algorithm_config;
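// If autotuning is enabled and this shape has not been seen before, profile
// every available backward-data algorithm and cache the fastest one.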
if (cudnn_use_autotune_ && !AutoTuneConv3dBwdData::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
#if GOOGLE_CUDA
se::TfAllocatorAdapter tf_allocator_adapter(
context->device()->GetAllocator({}), stream);
se::cuda::RedzoneAllocator rz_allocator(
stream, &tf_allocator_adapter, se::cuda::PtxCompilationOptions());
se::DeviceMemory<T> in_backprop_ptr_rz(
WrapRedzoneBestEffort(&rz_allocator, in_backprop_ptr));
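// Wrapping the output buffer in redzones lets CheckRedzones() detect
// out-of-bounds writes by a candidate algorithm during profiling.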
std::vector<AlgorithmDesc> algorithms;
CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
stream->parent()),
&algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
std::vector<tensorflow::AutotuneResult> results;
for (auto profile_algorithm : algorithms) {
// TODO(zhengxq): profile each algorithm multiple times for better
// accuracy.
DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
context);
se::cuda::RedzoneAllocator rz_scratch_allocator(
stream, &tf_allocator_adapter, se::cuda::PtxCompilationOptions(),
/*memory_limit=*/ConvolveBackwardDataScratchSize);
se::ScratchAllocator* allocator_used =
!RedzoneCheckDisabled()
? static_cast<se::ScratchAllocator*>(&rz_scratch_allocator)
: static_cast<se::ScratchAllocator*>(&scratch_allocator);
ProfileResult profile_result;
bool cudnn_launch_status =
stream
->ThenConvolveBackwardDataWithAlgorithm(
filter_desc, filter_ptr, output_desc, out_backprop_ptr,
conv_desc, input_desc, &in_backprop_ptr_rz, allocator_used,
AlgorithmConfig(profile_algorithm), &profile_result)
.ok();
if (cudnn_launch_status) {
if (profile_result.is_valid()) {
results.emplace_back();
auto& result = results.back();
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
result.mutable_conv()->set_tensor_ops_enabled(
profile_algorithm.tensor_ops_enabled());
result.set_scratch_bytes(
!RedzoneCheckDisabled()
? rz_scratch_allocator
.TotalAllocatedBytesExcludingRedzones()
: scratch_allocator.TotalByteSize());
*result.mutable_run_time() = proto_utils::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
if (profile_result.elapsed_time_in_ms() <
best_result.elapsed_time_in_ms()) {
best_result = profile_result;
}
if (scratch_allocator.TotalByteSize() == 0 &&
profile_result.elapsed_time_in_ms() <
best_result_no_scratch.elapsed_time_in_ms()) {
best_result_no_scratch = profile_result;
}
// TODO(george): they don't do results at all??
CheckRedzones(rz_scratch_allocator, &result);
CheckRedzones(rz_allocator, &result);
}
}
}
LogConvAutotuneResults(se::dnn::ConvolutionKind::BACKWARD_DATA,
se::dnn::ToDataType<T>::value, in_backprop_ptr,
filter_ptr, out_backprop_ptr, input_desc,
filter_desc, output_desc, conv_desc,
stream->parent(), results);
OP_REQUIRES(context,
best_result.is_valid() || best_result_no_scratch.is_valid(),
errors::NotFound("No algorithm worked!"));
if (best_result.is_valid()) {
algorithm_config.set_algorithm(best_result.algorithm());
}
if (best_result_no_scratch.is_valid()) {
algorithm_config.set_algorithm_no_scratch(
best_result_no_scratch.algorithm());
}
#elif TENSORFLOW_USE_ROCM
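// MIOpen runs its own algorithm search: a single call with a default
// AlgorithmConfig returns the best algorithm and its scratch size.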
DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
context);
ProfileResult best_result;
bool miopen_find_status =
stream
->ThenConvolveBackwardDataWithAlgorithm(
filter_desc, filter_ptr, output_desc, out_backprop_ptr,
conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
AlgorithmConfig(), &best_result)
.ok();
OP_REQUIRES(context, miopen_find_status && best_result.is_valid(),
errors::NotFound("Failed to find backward data algorithm!"));
algorithm_config.set_algorithm(best_result.algorithm());
algorithm_config.set_scratch_size(best_result.scratch_size());
#endif
AutoTuneConv3dBwdData::GetInstance()->Insert(conv_parameters,
algorithm_config);
}
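// Launch the convolution with the algorithm selected above (or previously
// cached for this shape).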
DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
context);
bool cudnn_launch_status =
stream
->ThenConvolveBackwardDataWithAlgorithm(
filter_desc, filter_ptr, output_desc, out_backprop_ptr,
conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
algorithm_config, nullptr)
.ok();
if (!cudnn_launch_status) {
context->SetStatus(errors::Internal(
"cuDNN Backward Data function launch failure : input shape(",
input_shape.DebugString(), ") filter shape(",
filter_shape.DebugString(), ")"));
}
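// If an extra element of padding was added above to make the total padding
// even, slice it back off the computed gradient.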
if (rows_odd || cols_odd || planes_odd) {
Tensor in_backprop_remove_padding;
OP_REQUIRES_OK(context,
context->allocate_temp(
DataTypeToEnum<T>::value,
{dims.batch_size, dims.in_depth, dims.input_size(0),
dims.input_size(1), dims.input_size(2)},
&in_backprop_remove_padding));
// Remove the padding for odd spatial dimensions.
functor::PadInput<GPUDevice, T, int, 5>()(
context->eigen_device<GPUDevice>(),
To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
.tensor<T, 5>()),
{{0, 0, 0}}, {{-planes_odd, -rows_odd, -cols_odd}},
To32Bit(in_backprop_remove_padding.tensor<T, 5>()), FORMAT_NCHW);
pre_transformed_in_backprop = in_backprop_remove_padding;
}
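// Finally, convert the gradient back to NHWC if that is the op's data
// format; otherwise the NCDHW result is used directly.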
if (data_format_ == FORMAT_NHWC) {
auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
functor::NCHWToNHWC<GPUDevice, T, 5>()(
context->eigen_device<GPUDevice>(),
toConstTensor(pre_transformed_in_backprop).template tensor<T, 5>(),
in_backprop->tensor<T, 5>());
} else {
*in_backprop = pre_transformed_in_backprop;
}
}