in modules.py
def backward(ctx, grad_output):
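    """
    Backward pass for a differentially private conv2d.

    Computes the gradient w.r.t. the input as usual, then computes
    per-example gradients w.r.t. the weight, scales each example's
    gradient by a multiplier derived from its norm and the clipping
    threshold `ctx.clip` (see `_get_multipliers`), sums the result over
    the batch, and adds Gaussian noise with standard deviation
    `ctx.clip * ctx.std`.
    """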
    # get input, kernel, and sizes:
    input, weight = ctx.saved_tensors
    batch_size = input.size(0)
    out_channels, in_channels, weight_size_y, weight_size_x = weight.size()
    assert input.size(1) == in_channels, "wrong number of input channels"
    assert grad_output.size(1) == out_channels, "wrong number of output channels"
    assert grad_output.size(0) == batch_size, "wrong batch size"
    # compute gradient with respect to input:
    grad_input = torch.nn.grad.conv2d_input(
        input.size(),
        weight,
        grad_output,
        stride=ctx.stride,
        padding=ctx.padding,
    )
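    # The per-example weight gradients below are obtained with a single
    # grouped convolution rather than a Python loop over the batch:
    # grad_output (B, out, H', W') is repeated and reshaped to
    # (B * in * out, 1, H', W'), the input to (1, B * in, H, W), and
    # torch.conv2d with groups = B * in then correlates every
    # (example, input-channel) slice with that example's output gradients
    # in one call. Note that the stride of the forward convolution becomes
    # the dilation of this gradient convolution.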
    # compute per-example gradient with respect to weights:
    # grad_output: (B, out, H', W') -> (B, in * out, H', W') -> (B * in * out, 1, H', W')
    grad_output = grad_output.contiguous().repeat(1, in_channels, 1, 1)
    grad_output = grad_output.contiguous().view(
        grad_output.shape[0] * grad_output.shape[1],
        1,
        grad_output.shape[2],
        grad_output.shape[3],
    )
    # input: (B, in, H, W) -> (1, B * in, H, W)
    input = input.contiguous().view(
        1,
        input.shape[0] * input.shape[1],
        input.shape[2],
        input.shape[3],
    )
    # grouped convolution computing all per-example weight gradients at once;
    # the forward stride acts as the dilation here:
    grad_weight = torch.conv2d(
        input,
        grad_output,
        bias=None,
        stride=1,
        padding=ctx.padding,
        dilation=ctx.stride,
        groups=in_channels * batch_size,
    )
    # fold the group dimension back into a batch dimension:
    # (1, B * in * out, kh, kw) -> (B, in * out, kh, kw)
    grad_weight = grad_weight.contiguous().view(
        batch_size,
        grad_weight.shape[1] // batch_size,
        grad_weight.shape[2],
        grad_weight.shape[3],
    )
    # compute norm of per-example weight gradients (L2 norm over each
    # example's flattened gradient):
    grad_norm = torch.norm(
        grad_weight.view(batch_size, -1), p='fro', dim=1, keepdim=True,
    ).view(batch_size, 1, 1, 1)
    # aggregate the clipped per-example weight gradients:
    multiplier = _get_multipliers(grad_norm, ctx.clip)
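    # NOTE: _get_multipliers is defined elsewhere in modules.py; presumably it
    # returns a per-example scaling factor such as min(1, ctx.clip / grad_norm)
    # so that each example's gradient norm is clipped to at most ctx.clip.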
    grad_weight = grad_weight.mul_(multiplier).sum(dim=0)
    # unflatten the channel dimension to (in, out, kh, kw), move it to the
    # conventional (out, in, kh, kw) layout, and trim any extra spatial
    # entries produced by the padded/dilated convolution above:
    grad_weight = grad_weight.view(
        in_channels,
        out_channels,
        grad_weight.shape[1],
        grad_weight.shape[2],
    ).transpose(0, 1).narrow(2, 0, weight_size_y).narrow(3, 0, weight_size_x)
    # add noise to gradient (Gaussian mechanism: noise std is the clipping
    # threshold ctx.clip times the noise multiplier ctx.std):
    grad_weight += torch.randn_like(grad_weight) * ctx.clip * ctx.std
    # return gradients; the trailing Nones correspond to forward arguments
    # that do not require gradients:
    return grad_input, grad_weight, None, None, None, None
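    # For reference, a forward() consistent with the six gradients returned
    # above would take six arguments and stash the hyper-parameters on ctx.
    # This is a hypothetical sketch (the actual forward lives elsewhere in
    # modules.py and may differ):
    #
    #   @staticmethod
    #   def forward(ctx, input, weight, stride=1, padding=0, clip=1.0, std=1.0):
    #       ctx.save_for_backward(input, weight)
    #       ctx.stride, ctx.padding, ctx.clip, ctx.std = stride, padding, clip, std
    #       return torch.conv2d(input, weight, None, stride, padding)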