def per_example_gradient()

in utils.py


import torch


def per_example_gradient(extr, clf, x, y, loss_fn, include_linear=False):
    # The classifier returns logits plus the activations and
    # pre-activations ("linearCombs") needed for the Goodfellow trick.
    logits, activations, linearCombs = clf(extr(x))
    loss = loss_fn(logits, y)
    # Retain the graph so the linear-layer gradients can be taken below.
    loss.backward(retain_graph=True)
    gradients = []
    for module in next(extr.children()).children():
        # The loss is a batch mean, so scale by the batch size to recover
        # each example's own gradient, then split the leading dimension
        # into (batch, filters).
        grad = module.expanded_weight.grad * x.size(0)
        gradients.append(grad.view(x.size(0), -1, grad.size(1), grad.size(2), grad.size(3)))
        if module.expanded_bias is not None:
            gradients.append(module.expanded_bias.grad.view(x.size(0), -1) * x.size(0))
    if include_linear:
        # Per-example linear-layer gradients, reconstructed from the stored
        # activations (goodfellow_backprop is defined elsewhere in utils.py).
        linearGrads = torch.autograd.grad(loss, linearCombs)
        linearGrads = goodfellow_backprop(activations, linearGrads)
        gradients = gradients + linearGrads
    return loss, gradients
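
For context, here is a minimal, self-contained sketch of the trick that goodfellow_backprop relies on, reduced to a single linear layer; the names H, Z, and W are illustrative and not part of utils.py. For a batch-mean loss, each example's weight gradient is the outer product of that example's pre-activation gradient (rescaled by the batch size, mirroring the x.size(0) factors above) and its input activation.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, D_in, D_out = 4, 5, 3
H = torch.randn(B, D_in)                          # layer input ("activations")
W = torch.randn(D_out, D_in, requires_grad=True)
Z = H @ W.t()                                     # pre-activations ("linearCombs")
loss = F.cross_entropy(Z, torch.tensor([0, 1, 2, 0]))

# Gradient of the mean loss w.r.t. the pre-activations; keep the graph
# so W.grad can still be computed for the sanity check below.
(dZ,) = torch.autograd.grad(loss, Z, retain_graph=True)

# Per-example weight gradients via a batched outer product, undoing the
# 1/B factor from the mean reduction: shape (B, D_out, D_in).
per_example = torch.einsum('bo,bi->boi', dZ * B, H)

# Sanity check: averaging the per-example gradients recovers W.grad.
loss.backward()
assert torch.allclose(per_example.mean(dim=0), W.grad, atol=1e-6)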