def get_loss_and_grads()

in grok/measure.py [0:0]


import numpy as np
import torch


def get_loss_and_grads(x, model, data_loader):
    """Load the flat parameter vector x into model, then return the mean loss
    (a scalar tensor) and the mean flattened gradient (a float64 numpy array)
    over all batches in data_loader."""

    # if type(x).__module__ == np.__name__:
    #     x = torch.from_numpy(x).float()
    #     x = x.cuda()

    model.eval()

    # Copy the flat parameter vector x back into the model: slice one chunk
    # per parameter tensor and reshape it to that parameter's shape.
    x_start = 0
    for p in model.parameters():
        param_size = p.data.size()
        param_count = p.data.numel()  # number of scalar entries in this parameter
        x_part = x[x_start : x_start + param_count]
        p.data = torch.Tensor(x_part.reshape(param_size))
        x_start += param_count

    # Accumulate the loss and flattened gradient for every batch in the loader.
    batch_losses = []
    batch_grads = []
    for batch in data_loader:

        # Move data to correct device
        # inputs = inputs.to(device)
        # targets = targets.to(device)

        with torch.set_grad_enabled(True):
            # loss, grads = model(idx=inputs, targets=targets, grads=True)
            loss, grads = model._step(batch=batch, batch_idx=1, train=True, grads=True)

        # TODO: average over dataset
        batch_losses.append(loss)
        # NOTE: the torch.stack below assumes grads is not None for every batch.
        # batch_grads.append(None if grads is None else grads.cpu().numpy().astype(np.float64))
        batch_grads.append(None if grads is None else grads)

    # Average across batches; the gradient mean is taken element-wise.
    mean_losses = torch.mean(torch.stack(batch_losses))
    mean_grads = torch.mean(torch.stack(batch_grads), dim=0)

    return (mean_losses, mean_grads.cpu().numpy().astype(np.float64))
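
A minimal usage sketch, assuming the same model and data_loader objects as above. The flatten_params helper is hypothetical (not part of grok/measure.py); it simply concatenates the parameters in the same order that get_loss_and_grads slices them back out, so the resulting vector x has exactly the layout the function expects.


import numpy as np
import torch


def flatten_params(model):
    # Hypothetical helper: concatenate every parameter tensor into one flat
    # float64 vector, mirroring the slicing order used in get_loss_and_grads.
    with torch.no_grad():
        flat = torch.cat([p.data.reshape(-1) for p in model.parameters()])
    return flat.cpu().numpy().astype(np.float64)


# Example call site (model and data_loader come from the surrounding code):
# x = flatten_params(model)
# loss, grad = get_loss_and_grads(x, model, data_loader)
# grad has the same length as x, so it can feed any routine that works on
# flat parameter vectors.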