in grok/measure.py [0:0]
def get_loss_and_grads(x, model, data_loader):
    """Load the flat weight vector *x* into *model* and evaluate it.

    The model is switched to eval mode, its parameters are overwritten
    (in ``model.parameters()`` order) with consecutive slices of ``x``, and
    ``model._step`` is then called once per batch of ``data_loader`` with
    gradients enabled.

    Args:
        x: 1-D array-like of weights (sliceable, with a ``reshape`` method,
            e.g. a NumPy array). Must hold at least as many elements as the
            model has parameters; each slice is reshaped to the matching
            parameter's shape.
        model: module exposing ``eval()``, ``parameters()`` and
            ``_step(batch=..., batch_idx=..., train=..., grads=...)``, the
            latter returning a ``(loss, grads)`` pair of tensors.
            NOTE(review): ``grads`` is assumed to have the same shape for
            every batch so the per-batch values can be stacked -- confirm
            against ``_step``.
        data_loader: iterable of batches passed to ``model._step``.

    Returns:
        Tuple ``(mean_loss, mean_grads)``: ``mean_loss`` is a torch tensor
        holding the mean of the per-batch losses, ``mean_grads`` is a
        ``float64`` NumPy array holding the per-batch gradients averaged
        over the batch axis.

    Raises:
        RuntimeError: if ``data_loader`` yields no batches.
        TypeError: if ``model._step`` returns ``None`` for the gradients.
    """
    model.eval()

    # Scatter the flat vector back into the model, one parameter at a time.
    offset = 0
    for p in model.parameters():
        count = p.numel()
        chunk = x[offset : offset + count]
        # Keep the parameter's own dtype and device: the legacy
        # ``torch.Tensor(...)`` constructor always produces a default-dtype
        # CPU tensor, which silently down-casts / relocates the weights.
        # ``clone`` guarantees the parameter never aliases the caller's buffer.
        p.data = torch.as_tensor(
            chunk.reshape(tuple(p.shape)), dtype=p.dtype, device=p.device
        ).clone()
        offset += count

    batch_losses = []
    batch_grads = []
    for batch in data_loader:
        # Gradients are needed even though the model is in eval mode.
        with torch.set_grad_enabled(True):
            loss, grads = model._step(
                batch=batch, batch_idx=1, train=True, grads=True
            )
        if grads is None:
            raise TypeError(
                "model._step returned no gradients (grads is None); "
                "cannot average gradients over batches"
            )
        batch_losses.append(loss)
        batch_grads.append(grads)

    if not batch_losses:
        raise RuntimeError(
            "data_loader yielded no batches; cannot compute loss and grads"
        )

    # TODO: this is an unweighted mean over batches, not over samples; it
    # differs from the dataset average when batch sizes are unequal.
    mean_losses = torch.mean(torch.stack(batch_losses))
    mean_grads = torch.mean(torch.stack(batch_grads), dim=0)

    return (mean_losses, mean_grads.cpu().numpy().astype(np.float64))