in modules.py [0:0]
def backward(ctx, grad_output):
    # unpack tensors saved by the forward pass:
    input, weight, bias = ctx.saved_tensors
    # dimensions over which the affine transform is broadcast (all but the
    # channel dimension); per-example gradients are summed over these below:
    stats_dimensions = list(range(grad_output.dim()))
    stats_dimensions.pop(1)
    # shape for broadcasting weight and bias with grad_output:
    broadcast_shape = [1] * input.dim()
    broadcast_shape[1] = input.shape[1]
    weight = weight.reshape(broadcast_shape)
    bias = bias.reshape(broadcast_shape)
    # compute gradient with respect to the input:
    grad_input = grad_output.mul(weight)
    # compute per-example gradients with respect to weight and bias:
    grad_weight = grad_output.mul(input)
    grad_bias = grad_output.clone()
    # compute norms of the per-example gradients:
    batch_size = grad_output.size(0)
    grad_weight_norm = torch.norm(
        grad_weight.view(batch_size, -1), p='fro', dim=1, keepdim=True,
    )
    grad_bias_norm = torch.norm(
        grad_bias.view(batch_size, -1), p='fro', dim=1, keepdim=True,
    )
    # shape for broadcasting per-example multipliers with grad_output:
    broadcast_shape = [1] * grad_output.dim()
    broadcast_shape[0] = batch_size
    # clip and aggregate the per-example weight gradients:
    multiplier = _get_multipliers(grad_weight_norm, ctx.clip)
    multiplier = multiplier.reshape(broadcast_shape)
    grad_weight = grad_weight.mul_(multiplier).sum(stats_dimensions)
    # clip and aggregate the per-example bias gradients:
    multiplier = _get_multipliers(grad_bias_norm, ctx.clip)
    multiplier = multiplier.reshape(broadcast_shape)
    grad_bias = grad_bias.mul_(multiplier).sum(stats_dimensions)
    # add Gaussian noise scaled by the clipping threshold to both gradients:
    grad_weight += torch.randn_like(grad_weight) * ctx.clip * ctx.std
    grad_bias += torch.randn_like(grad_bias) * ctx.clip * ctx.std
    # return gradients (None for the non-tensor arguments of forward):
    return grad_input, grad_weight, grad_bias, None, None
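
# `_get_multipliers` is called above but not defined in this excerpt. Below is
# a minimal sketch of what it likely computes, assuming the standard DP-SGD
# per-example clipping factor min(1, clip / norm); the body, the epsilon term,
# and the handling of zero norms are assumptions, not taken from modules.py.
def _get_multipliers(norms, clip):
    # per-example scaling factors: gradients whose norm exceeds `clip` are
    # scaled down to norm `clip`; smaller gradients are left unchanged
    return torch.clamp(clip / (norms + 1e-6), max=1.0)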