in fast_grad/gradient_funcs.py [0:0]
def naive(model, X, y):
    """
    Compute per-example gradients with one backward pass per example.

    The predictions are computed once, in a full-batch fashion. Then, for
    each example, the parameter gradients are zeroed and ``backward`` is
    called on that example's binary cross-entropy loss. The autograd graph is
    retained between examples so the single forward pass can be reused.

    Args:
        model: module whose ``forward(X)`` returns a 3-tuple; only the first
            element (the logits) is used.
        X: input batch; ``X.shape[0]`` is taken as the batch size ``N``.
        y: targets; ``y[n].view(-1)`` is the target for ``logits[n]``.

    Returns:
        A list with one tensor per parameter of ``model``, in
        ``model.parameters()`` order. Each tensor has shape
        ``(N, *param.shape)``, and slice ``n`` is the gradient of example
        ``n``'s loss. Parameters that receive no gradient get zeros.

    Side effects:
        Overwrites ``.grad`` on the model's parameters. On return they hold
        the gradient of the last example (untouched if ``N == 0``).
    """
    logits, _, _ = model.forward(X)
    # Listed after forward so any lazily-initialised parameters exist.
    params = list(model.parameters())
    N = X.shape[0]
    if N == 0:
        # torch.stack cannot stack an empty list; return empty per-example grads.
        return [p.new_zeros((0, *p.shape)) for p in params]

    per_example = []
    for n in range(N):
        model.zero_grad()
        loss = F.binary_cross_entropy_with_logits(logits[n], y[n].view(-1,))
        # Keep the graph for every backward except the last, which frees it.
        loss.backward(retain_graph=n < N - 1)
        # zero_grad() may leave .grad as None (set_to_none=True is the default
        # since PyTorch 2.0) for parameters that do not affect the loss.
        # Clone so later zeroing/accumulation cannot alias the stored values.
        per_example.append([
            torch.zeros_like(p) if p.grad is None else p.grad.detach().clone()
            for p in params
        ])
    # Transpose example-major lists into one (N, *shape) tensor per parameter.
    return [torch.stack(grads_for_param) for grads_for_param in zip(*per_example)]