in crypten/gradients.py [0:0]
def backward(ctx, grad_output):
# Gradient function adapts code from:
# https://github.com/pytorch/pytorch/blob/master/torch/nn/grad.py
# get input, kernel, and sizes:
input, kernel, padding, stride, dilation, groups = ctx.saved_tensors
batch_size = input.size(0)
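# kernel is stored as (out_channels, in_channels // groups, kernel_size),
# so recover the true input-channel count below: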
out_channels, in_channels, kernel_size = kernel.size()
in_channels *= groups
assert input.size(1) == in_channels, "wrong number of input channels"
assert grad_output.size(1) == out_channels, "wrong number of output channels"
assert grad_output.size(0) == batch_size, "wrong batch size"
# TODO: Implement conv1d gradient under the following condition:
if groups > 1 and input.size(1) > groups:
raise NotImplementedError(
"conv1d backward with groups > 1 and in_channels > groups not implemented"
)
# compute gradient with respect to input:
# TODO: Eliminate dependency on torch internal function by implementing it in util
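# _grad_input_padding computes the output_padding needed so the transposed
# convolution below reproduces the exact spatial size of `input`: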
output_padding = torch.nn.grad._grad_input_padding(
grad_output,
input.size(),
stride,
padding,
(kernel_size,),
dilation=dilation,
)
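# the input gradient is the transposed convolution of grad_output with the kernel: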
grad_input = grad_output.conv_transpose1d(
kernel,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
)
# compute gradient with respect to kernel:
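# Trick from torch/nn/grad.py: fold the batch and input-channel dimensions
# into the groups dimension so a single grouped conv1d of the input with
# grad_output yields a per-(example, channel) kernel gradient: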
grad_output = grad_output.repeat(1, in_channels // groups, 1)
grad_output = grad_output.view(
grad_output.size(0) * grad_output.size(1), 1, grad_output.size(2)
)
input = input.view(1, input.size(0) * input.size(1), input.size(2))
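# stride and dilation are intentionally swapped relative to the forward pass,
# matching the conv1d_weight derivation in torch/nn/grad.py: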
grad_kernel = input.conv1d(
grad_output,
stride=dilation,
padding=padding,
dilation=stride,
groups=in_channels * batch_size,
)
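# un-fold the batch dimension and sum the per-example kernel gradients: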
grad_kernel = grad_kernel.view(
batch_size, grad_kernel.size(1) // batch_size, grad_kernel.size(2)
)
grad_kernel = grad_kernel.sum(dim=0)
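# reshape to (in_channels // groups, out_channels, length), swap the channel
# dimensions, and trim to kernel_size (a strided backward convolution can
# produce extra trailing positions):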
grad_kernel = grad_kernel.view(
in_channels // groups, out_channels, grad_kernel.size(1)
)
grad_kernel = grad_kernel.transpose(0, 1).narrow(2, 0, kernel_size)
return (grad_input, grad_kernel)