in utils.py [0:0]
def compute_full_grad(model, device, data_loader, loss_fn, lam=0):
    """Accumulate a batch-size-weighted gradient vector over ``data_loader``.

    For every batch, the inputs are moved to ``device``, ``loss_with_reg`` is
    invoked, and the vector returned by ``params_to_vec(..., grad=True)`` is
    scaled by ``batch_size / len(data_loader.dataset)`` and added to a running
    total. Presumably ``loss_with_reg`` runs the backward pass that populates
    the gradients -- confirm against its definition.

    Args:
        model: Object exposing ``zero_grad()`` and ``parameters()``.
        device: Target handed to ``.to()`` for each batch's data and target.
        data_loader: Iterable of ``(data, target)`` pairs that also has a
            ``dataset`` attribute supporting ``len()``.
        loss_fn: Forwarded unchanged to ``loss_with_reg``.
        lam: Forwarded unchanged to ``loss_with_reg`` (default 0); presumably
            the regularization strength -- verify against ``loss_with_reg``.

    Returns:
        A ``(full_grad, param_vec)`` tuple. ``full_grad`` is the weighted sum
        described above, or ``None`` when the loader yields no batches.
        ``param_vec`` is the result of ``params_to_vec(model.parameters())``.
    """
    model.zero_grad()
    accumulated = None

    for inputs, labels in data_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Called for its side effect only; the return value is discarded.
        loss_with_reg(model, inputs, labels, loss_fn, lam)
        batch_grad = params_to_vec(model.parameters(), grad=True)

        # Weight each batch by its share of the dataset so that a smaller
        # final batch contributes proportionally.
        weighted = batch_grad * inputs.size(0) / len(data_loader.dataset)
        if accumulated is None:
            accumulated = weighted
        else:
            accumulated += weighted

        # NOTE(review): reset per batch so the next read sees only that
        # batch's gradient -- assumes this call belongs inside the loop;
        # confirm against the original indentation.
        model.zero_grad()

    return accumulated, params_to_vec(model.parameters())