def train_eval_model()

in neural/train.py
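
Several names used below are not shown in this excerpt. The following imports
and SavedEval container are a minimal sketch, assumed from how they are used in
the function; the real module may define them differently, and permute_forcing
is likewise assumed to be a helper defined elsewhere in neural/train.py.

from collections import namedtuple

import torch as th
from torch import nn
from torch.utils import data
from tqdm import tqdm

# Sketch only: exactly the five fields the function populates when save=True.
SavedEval = namedtuple(
    "SavedEval", ["megs", "forcings", "predictions", "lengths", "subjects"])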


def train_eval_model(dataset,
                     model,
                     optimizer=None,
                     progress=True,
                     train=True,
                     save=False,
                     device="cpu",
                     batch_size=128,
                     permut_feature=None,
                     criterion=nn.MSELoss()):
    '''Train and eval function: runs a single pass over the dataset.

    Inputs:
    ...
    - save: if True, the second output is a SavedEval namedtuple with
            megs, [N, C, T]
            forcings, dict of values [N, 1, T]
            predictions, [N, C, T]  # if regularization, T can change
            lengths, [N]
            subjects, [N]
    '''

    dataloaded = data.DataLoader(dataset, batch_size=batch_size, shuffle=train)
    if train:
        desc = "train set"
        model.train()
    else:
        desc = "test set"
        model.eval()

    running_loss = 0

    if progress:
        dl_iter = tqdm(dataloaded, leave=False, ncols=120, total=len(dataloaded), desc=desc)
    else:
        dl_iter = iter(dataloaded)

    if save:
        saved = SavedEval([], [], [], [], [])

    batch_idx = 0  # fallback so the average below is defined for an empty loader

    for batch_idx, batch in enumerate(dl_iter):

        # Unpack batch and load onto device (e.g. GPU)
        meg, forcings, length, subject_id = batch  # [B, C, T], dict of values [B, C, 1], [B], [B]
        meg = meg.to(device)
        forcings = {k: v.to(device) for k, v in forcings.items()}
        subject_id = subject_id.to(device)
        true_subject_id = subject_id  # keep the unpermuted ids for saving
        length = length.to(device)

        batch_len, channels, n_times = meg.size()

        meg_true = meg  # keep the unpermuted signal as the target

        # Permute an input feature (to measure its importance at test time)
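        # (Permutation importance: one feature is shuffled across the batch while
        # the targets stay fixed; the increase in loss relative to an unpermuted
        # run measures how much the model relies on that feature.)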
        if permut_feature is not None:
            permutation = th.randperm(batch_len, device=device)
            if permut_feature == "meg":
                permutation = permutation.view(-1, 1, 1).expand(-1, channels, n_times)
                meg = th.gather(meg, 0, permutation)
            elif permut_feature == "subject":
                subject_id = th.gather(subject_id, 0, permutation)
            else:
                forcing = forcings[permut_feature]
                forcings[permut_feature] = permute_forcing(forcings["first_mask"], forcing,
                                                           permutation)
        saved_forcings = forcings  # possibly permuted; saved as-is below

        # Predict, then compute the loss twice: loss_train over the full window
        # drives backprop, while loss skips the first model.meg_init warm-up
        # steps and is the value that gets reported and averaged.
        meg_pred = model(meg, forcings, subject_id)
        loss_train = criterion(meg_pred, meg_true)
        loss = criterion(meg_pred[..., model.meg_init:], meg_true[..., model.meg_init:])
        running_loss += loss.item()

        if train:
            loss_train.backward()
            optimizer.step()
            optimizer.zero_grad()

        if save:
            # All quantities (meg, length, subject_id) are saved in their
            # original, unpermuted state; forcings are saved post-permutation.
            saved.megs.append(meg_true.cpu())
            saved.forcings.append({k: v.cpu() for k, v in saved_forcings.items()})
            saved.predictions.append(meg_pred.detach().cpu())
            saved.lengths.append(length.cpu())
            saved.subjects.append(true_subject_id.cpu())
        if progress:
            dl_iter.set_postfix(loss=running_loss / (batch_idx + 1))

    n_batches = batch_idx + 1  # idx starts at 0
    running_loss /= n_batches  # average over batches

    if save:
        saved = SavedEval(
            megs=th.cat(saved.megs),
            forcings={k: th.cat([v[k] for v in saved.forcings])
                      for k in forcings},
            predictions=th.cat(saved.predictions),
            lengths=th.cat(saved.lengths),
            subjects=th.cat(saved.subjects))
    else:
        saved = None

    return running_loss, saved
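
A minimal usage sketch, not from the source: the dataset names, the Adam
optimizer, the epoch count, and the "audio" forcing key are all placeholder
assumptions; only the train_eval_model signature itself comes from above.

# train_set, test_set, and model are placeholders assumed to exist upstream.
device = "cuda" if th.cuda.is_available() else "cpu"
model = model.to(device)
optimizer = th.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    train_loss, _ = train_eval_model(train_set, model, optimizer=optimizer,
                                     train=True, device=device)
    with th.no_grad():
        test_loss, saved = train_eval_model(test_set, model, train=False,
                                            save=True, device=device)
    print(f"epoch {epoch}: train {train_loss:.4f}, test {test_loss:.4f}")

# Permutation importance of a single feature (hypothetical forcing key "audio"):
# the loss increase over the unpermuted eval measures that feature's importance.
with th.no_grad():
    perm_loss, _ = train_eval_model(test_set, model, train=False, device=device,
                                    permut_feature="audio")
importance = perm_loss - test_loss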