def trainvalidate()

in experiment.py [0:0]
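
The body below assumes that time and torch are imported at module level in experiment.py, together with a project helper get_net_input(batch); a minimal sketch of those assumed module-level imports:

import time

import torch

# get_net_input is assumed to be a helper defined elsewhere in the repository;
# it moves the entries of a batch dict onto the GPU where possible.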


def trainvalidate(model,
                  stats,
                  epoch,
                  loader,
                  optimizer,
                  validation,
                  bp_var='objective',
                  metric_print_interval=5,
                  visualize_interval=100,
                  visdom_env_root='trainvalidate',
                  **kwargs):
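    """Run a single epoch of training or validation.

    Iterates over `loader`, runs the model forward, logs metrics through
    `stats`, and periodically prints them and visualises predictions in a
    Visdom environment. When `validation` is False, the loss stored in
    `preds[bp_var]` is backpropagated and `optimizer` is stepped after
    every batch.
    """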

    if validation:
        model.eval()
        trainmode = 'val'
    else:
        model.train()
        trainmode = 'train'

    t_start = time.time()

    # clear the visualisations only when visualizing the first batch of the epoch
    clear_visualisations = True

    # get the visdom env name
    visdom_env_imgs = visdom_env_root + "_images_" + trainmode

    n_batches = len(loader)
    for it, batch in enumerate(loader):

        last_iter = it == n_batches - 1

        # move to gpu where possible
        net_input = get_net_input(batch)

        # the forward pass
        if not validation:
            optimizer.zero_grad()
            preds = model(**net_input)
        else:
            with torch.no_grad():
                preds = model(**net_input)

        # make sure the predictions don't overwrite any of the network inputs
        assert not any(k in preds for k in net_input.keys())
        preds.update(net_input)  # merge everything into one big dict

        # update the stats logger
        stats.update(preds, time_start=t_start, stat_set=trainmode)
        assert stats.it[trainmode] == it, "inconsistent stat iteration number!"

        # print textual status update
        if (it % metric_print_interval) == 0 or last_iter:
            stats.print(stat_set=trainmode, max_it=n_batches)

        # visualize results
        if (visualize_interval > 0) and (it % visualize_interval) == 0:
            model.visualize(visdom_env_imgs, trainmode,
                            preds, stats, clear_env=clear_visualisations)
            clear_visualisations = False

        # backward pass and optimizer step
        if not validation:
            loss = preds[bp_var]
            loss.backward()
            optimizer.step()
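
A minimal sketch of an outer loop that might drive trainvalidate, alternating one training and one validation pass per epoch; everything other than trainvalidate itself (the model, stats, loaders, optimizer objects and the stats.new_epoch call) is an assumption for illustration, not part of the excerpt:

# hypothetical driver loop; all names except trainvalidate() are assumptions
for epoch in range(num_epochs):
    stats.new_epoch()  # assumed Stats API for starting a fresh set of epoch counters

    # one full pass over the training set with gradient updates
    trainvalidate(model, stats, epoch, train_loader, optimizer,
                  validation=False, bp_var='objective')

    # one full pass over the validation set with gradients disabled
    trainvalidate(model, stats, epoch, val_loader, optimizer,
                  validation=True)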