in crypten/gradients.py [0:0]
def backward(ctx, grad_output):
    # Gradient function adapts code from:
    # https://github.com/pytorch/pytorch/blob/master/torch/nn/grad.py

    # get input, kernel, and sizes:
    input, kernel, padding, stride, dilation, groups = ctx.saved_tensors
    batch_size = input.size(0)
    out_channels, in_channels, kernel_size_y, kernel_size_x = kernel.size()
    in_channels *= groups  # kernel dim 1 holds in_channels // groups
    assert input.size(1) == in_channels, "wrong number of input channels"
    assert grad_output.size(1) == out_channels, "wrong number of output channels"
    assert grad_output.size(0) == batch_size, "wrong batch size"
    # TODO: implement the conv2d gradient for the following unsupported case:
    if groups > 1 and input.size(1) > groups:
        raise NotImplementedError(
            "conv2d backward with groups > 1 and in_channels > groups not implemented"
        )
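    # note: grouped convolutions are thus supported only in the depthwise-style
    # case in_channels == groups, i.e. one input channel per group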
    # compute gradient with respect to input:
    # TODO: eliminate the dependency on this private torch function (it is
    # not a stable API) by implementing it in util
    output_padding = torch.nn.grad._grad_input_padding(
        grad_output,
        input.size(),
        stride,
        padding,
        (kernel_size_y, kernel_size_x),
        dilation=dilation,
    )
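    # when stride > 1, several input sizes map to the same conv2d output size;
    # output_padding tells conv_transpose2d which one to produce so that
    # grad_input has exactly input.size()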
    grad_input = grad_output.conv_transpose2d(
        kernel,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        groups=groups,
        dilation=dilation,
    )
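    # the input gradient of a convolution is the transposed convolution
    # (the adjoint of the forward conv2d) of grad_output with the same kernel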
    # compute gradient with respect to kernel:
    grad_output = grad_output.repeat(1, in_channels // groups, 1, 1)
    grad_output = grad_output.view(
        grad_output.size(0) * grad_output.size(1),
        1,
        grad_output.size(2),
        grad_output.size(3),
    )
    input = input.view(
        1, input.size(0) * input.size(1), input.size(2), input.size(3)
    )
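    # the kernel gradient is itself a convolution: with the batch folded into
    # the channel dimension and groups = in_channels * batch_size, each
    # (example, input-channel) slice of input is correlated against the
    # matching slices of grad_output; the per-example results are summed below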
    # dilation and stride are swapped relative to the forward pass, following
    # PyTorch's conv2d_weight implementation
    grad_kernel = input.conv2d(
        grad_output,
        stride=dilation,
        padding=padding,
        dilation=stride,
        groups=in_channels * batch_size,
    )
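    # intuition for the swap: treating grad_output as the filter, its taps land
    # on input locations that the forward pass visited, which are stride pixels
    # apart (hence dilation=stride), while successive kernel-gradient entries
    # correspond to input offsets dilation pixels apart (hence stride=dilation)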
    grad_kernel = grad_kernel.view(
        batch_size,
        grad_kernel.size(1) // batch_size,
        grad_kernel.size(2),
        grad_kernel.size(3),
    )
    grad_kernel = (
        grad_kernel.sum(0)
        .view(
            in_channels // groups,
            out_channels,
            grad_kernel.size(2),
            grad_kernel.size(3),
        )
        .transpose(0, 1)
    )
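    # with stride > 1 the correlation above may emit extra trailing rows and
    # columns; trim the result back to the true kernel size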
    grad_kernel = grad_kernel.narrow(2, 0, kernel_size_y)
    grad_kernel = grad_kernel.narrow(3, 0, kernel_size_x)
    return (grad_input, grad_kernel)
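A plaintext sanity check of the same recipe (a minimal sketch, not part of
gradients.py, using ordinary torch tensors rather than CrypTen's encrypted
tensors; the shapes and the hard-coded output_padding=1 are illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8, requires_grad=True)
w = torch.randn(4, 3, 3, 3, requires_grad=True)
y = F.conv2d(x, w, stride=2, padding=1)
grad_out = torch.randn_like(y)
y.backward(grad_out)

# gradient w.r.t. input: transposed convolution of grad_out with the kernel;
# output_padding=1 recovers the 8x8 input size for these particular shapes
grad_x = F.conv_transpose2d(grad_out, w, stride=2, padding=1, output_padding=1)
assert torch.allclose(grad_x, x.grad, atol=1e-5)

# gradient w.r.t. the kernel, via the PyTorch helper this code adapts
grad_w = torch.nn.grad.conv2d_weight(
    x.detach(), w.shape, grad_out, stride=2, padding=1
)
assert torch.allclose(grad_w, w.grad, atol=1e-4)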