in utils.py [0:0]
def per_example_gradient(extr, clf, x, y, loss_fn, include_linear=False):
    """Compute a batch loss and per-example gradients of the expanded layers.

    The forward pass is ``clf(extr(x))``; ``clf`` must return a 3-tuple
    ``(logits, activations, linearCombs)``. The scalar ``loss_fn(logits, y)``
    is back-propagated once with ``retain_graph=True`` so the graph is still
    available for the optional ``torch.autograd.grad`` call further down.

    Per-example gradients are read from the sub-modules of the *first* child
    of ``extr``. Every such sub-module must expose ``expanded_weight`` (whose
    gradient has at least four dimensions, since sizes 1-3 are read) and
    ``expanded_bias`` (a parameter or ``None``).

    Args:
        extr: Feature-extractor module whose first child holds the layers
            carrying ``expanded_weight`` / ``expanded_bias``.
        clf: Callable mapping ``extr(x)`` to
            ``(logits, activations, linearCombs)``.
        x: Input batch; ``x.size(0)`` is used as the batch size.
        y: Targets, passed unchanged to ``loss_fn``.
        loss_fn: Callable ``loss_fn(logits, y)`` returning a scalar tensor.
        include_linear: When true, also compute gradients w.r.t.
            ``linearCombs`` and feed them, together with ``activations``,
            through ``goodfellow_backprop``; its result (which must be a
            list, because it is concatenated with ``+``) is appended.

    Returns:
        ``(loss, gradients)``. For each expanded layer, ``gradients`` holds
        the weight gradient reshaped to
        ``(batch, -1, g.size(1), g.size(2), g.size(3))``, followed, when the
        layer has a bias, by the bias gradient reshaped to ``(batch, -1)``.
        Both are multiplied by the batch size.

    Note:
        ``loss.backward`` accumulates into existing ``.grad`` tensors, so any
        gradients left over from an earlier call would be mixed into the
        result. NOTE(review): presumably callers zero grads between calls —
        confirm.
    """
    logits, activations, linear_combs = clf(extr(x))
    loss = loss_fn(logits, y)
    loss.backward(retain_graph=True)

    n_examples = x.size(0)
    per_example = []
    for layer in next(extr.children()).children():
        # Multiply by the batch size — presumably undoing a mean reduction in
        # loss_fn (TODO confirm) — then split the leading dimension into
        # (batch, -1) so each example gets its own slice.
        w_grad = layer.expanded_weight.grad * n_examples
        per_example.append(
            w_grad.view(n_examples, -1, w_grad.size(1), w_grad.size(2), w_grad.size(3))
        )
        bias = layer.expanded_bias
        if bias is not None:
            per_example.append(bias.grad.view(n_examples, -1) * n_examples)

    if not include_linear:
        return loss, per_example

    # Second gradient pass over the graph retained by backward() above.
    linear_grads = torch.autograd.grad(loss, linear_combs)
    linear_grads = goodfellow_backprop(activations, linear_grads)
    return loss, per_example + linear_grads