in nestedtensor/csrc/conv2d.cpp [108:195]
Tensor NestedTensor_cudnn_convolution_relu(
    const Tensor& input_,
    const Tensor& weight,
    const c10::optional<Tensor>& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups) {
  Tensor input = input_;
  TORCH_CHECK(get_dim(input) == 4, "Expected input to be dim 4, but got ", get_dim(input), ".");
#ifdef WITH_CUDA
  auto self_opt_sizes = get_opt_sizes(input);
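  // Fast path: a 1x1, stride-1, padding-0, dilation-1, groups=1 convolution
  // with no bias is just a per-pixel matmul over the channel dimension, so it
  // can run directly on the packed nested buffer without materializing padding.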
  if (is_nested_tensor_impl(input) &&
      !is_nested_tensor_impl(weight) &&
      (input.dtype() == torch::kFloat16 || input.dtype() == torch::kFloat32)) {
    // get_dim(input) == 4 is already guaranteed by the TORCH_CHECK above.
    if (!bias && weight.size(2) == 1 && weight.size(3) == 1 &&
        stride[0] == 1 && stride[1] == 1 &&
        padding[0] == 0 && padding[1] == 0 &&
        dilation[0] == 1 && dilation[1] == 1 &&
        groups == 1 &&
        self_opt_sizes[0].has_value() &&
        self_opt_sizes[1].has_value() &&
        get_is_cuda(input)) {
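      // Channels-last input: each constituent is laid out as an (H*W, C)
      // row-block, so the packed buffer can be viewed as (total_pixels, C).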
      if (get_is_contiguous(input, c10::MemoryFormat::ChannelsLast)) {
        Tensor input_buffer = get_buffer(input);
        input_buffer = input_buffer.view({-1, weight.size(1)});
        at::Tensor result_buffer = at::matmul(
            input_buffer,
            weight.reshape({weight.size(0), weight.size(1)}).transpose(0, 1));
        int64_t weight_size_0 = weight.size(0);
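        // Only the channel dimension changes: a 1x1 convolution with stride 1
        // and no padding preserves each constituent's spatial extent.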
        auto new_sizes = map_efficient_size(
            [&weight_size_0](int64_t* size_ptr, int64_t size) {
              size_ptr[0] = weight_size_0;
            },
            get_efficient_nested_size(input));
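        // Rebuild channels-last strides for each (O, H, W) constituent:
        // (1, O * W, O), i.e. channels stay innermost in memory.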
        auto new_strides = map_efficient_size(
            [](int64_t* size_ptr, int64_t size) {
              int64_t tmp2 = size_ptr[2];
              size_ptr[2] = size_ptr[0];
              size_ptr[1] = size_ptr[2] * tmp2;
              size_ptr[0] = 1;
            },
            new_sizes);
        return wrap_buffer(result_buffer.view(-1), new_sizes, new_strides);
      }
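      // Contiguous NCHW input: transpose to NHWC first so channels become
      // innermost, apply the same matmul, then transpose the result back.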
      if (get_is_contiguous(input)) {
        input = transpose_nchw_nhwc(input);
        Tensor input_buffer = get_buffer(input);
        input_buffer = input_buffer.reshape({-1, weight.size(1)});
        at::Tensor result_buffer = at::matmul(
            input_buffer,
            weight.reshape({weight.size(0), weight.size(1)}).transpose(0, 1));
        int64_t weight_size_0 = weight.size(0);
        auto new_sizes = map_efficient_size(
            [&weight_size_0](int64_t* size_ptr, int64_t size) {
              size_ptr[2] = weight_size_0;
            },
            get_efficient_nested_size(input));
        Tensor result = wrap_buffer(result_buffer.reshape(-1), new_sizes);
        return transpose_nhwc_nchw(result);
      }
    }
  }
#endif
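  // Half-precision fallback: pad the nested input into one dense batch, run
  // the fused cudnn kernel once, then crop back to per-tensor output sizes.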
  if (input.dtype() == torch::kFloat16) {
    at::Tensor data = to_padded_tensor(input, 0);
    at::Tensor result_data = at::cudnn_convolution_relu(
        data, weight, bias, stride, padding, dilation, groups);
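    // Per-tensor output sizes follow the standard convolution arithmetic:
    // out = (in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1.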
    auto new_sizes = map_efficient_size(
        [&weight, &stride, &padding, &dilation](int64_t* size_ptr, int64_t size) {
          size_ptr[0] = weight.size(0);
          size_ptr[1] = ((size_ptr[1] + 2 * padding[0] - dilation[0] * (weight.size(2) - 1) - 1) / stride[0]) + 1;
          size_ptr[2] = ((size_ptr[2] + 2 * padding[1] - dilation[1] * (weight.size(3) - 1) - 1) / stride[1]) + 1;
        },
        get_efficient_nested_size(input));
    Tensor result = from_padded_tensor(result_data, new_sizes);
    if (get_is_contiguous(input, c10::MemoryFormat::ChannelsLast)) {
      return NestedTensor_contiguous(result, c10::MemoryFormat::ChannelsLast);
    }
    return result;
  }
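  // Generic fallback: apply the fused kernel tensor-by-tensor. unsqueeze(0)
  // adds the batch dimension cudnn expects; squeeze(0) removes it again.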
  if (bias) {
    return map_nested_tensor(
        [&stride, &padding, &dilation, &groups](
            at::Tensor input, at::Tensor weight, at::Tensor bias) {
          return at::cudnn_convolution_relu(
                     input.unsqueeze(0), weight, bias, stride, padding, dilation, groups)
              .squeeze(0);
        },
        input,
        weight,
        *bias);
  }
  return map_nested_tensor(
      [&stride, &padding, &dilation, &groups](at::Tensor input, at::Tensor weight) {
        return at::cudnn_convolution_relu(
                   input.unsqueeze(0), weight, c10::nullopt, stride, padding, dilation, groups)
            .squeeze(0);
      },
      input,
      weight);
}