def compute_full_grad()

in utils.py [0:0]


def compute_full_grad(model, device, data_loader, loss_fn, lam=0):
    """Return the sample-weighted average gradient over ``data_loader``.

    Each batch gradient is weighted by its batch size and the total is
    divided by the number of samples actually iterated, so the weights
    always sum to one. Normalising by ``len(data_loader.dataset)`` instead
    would under-scale the result whenever the loader does not yield every
    dataset sample exactly once (``drop_last=True``, a subset sampler) and
    would fail for datasets that do not define ``__len__``.

    Args:
        model: Module whose parameters' gradients are accumulated. Its
            gradients are zeroed before the first batch and after every
            batch, so no gradient state is left behind.
        device: Device each batch is moved to before the loss is evaluated.
        data_loader: Iterable yielding ``(data, target)`` tensor pairs.
        loss_fn: Loss callable forwarded to ``loss_with_reg``.
            NOTE(review): the batch-size weighting presumes the loss is a
            mean over the batch -- confirm against ``loss_with_reg``.
        lam: Regularisation coefficient forwarded to ``loss_with_reg``.

    Returns:
        A ``(full_grad, param_vec)`` tuple. ``full_grad`` is the flattened
        averaged gradient, or ``None`` if the loader yielded no batches.
        ``param_vec`` is ``params_to_vec(model.parameters())``.
    """
    grad_sum = None
    num_samples = 0
    model.zero_grad()
    for data, target in data_loader:
        data, target = data.to(device), target.to(device)
        # The return value is unused; the gradients are read from the model
        # right after this call.
        # NOTE(review): presumably loss_with_reg runs the backward pass
        # itself -- confirm in its definition.
        loss_with_reg(model, data, target, loss_fn, lam)
        batch_size = data.size(0)
        # The multiplication yields a fresh tensor, so the in-place add
        # below never aliases whatever params_to_vec returned.
        weighted_grad = params_to_vec(model.parameters(), grad=True) * batch_size
        if grad_sum is None:
            grad_sum = weighted_grad
        else:
            grad_sum += weighted_grad
        num_samples += batch_size
        model.zero_grad()
    # Skip the division when nothing was counted: an empty loader leaves
    # grad_sum as None, matching the previous return value for that case.
    full_grad = grad_sum / num_samples if num_samples else grad_sum
    param_vec = params_to_vec(model.parameters())
    return full_grad, param_vec