in experiment.py [0:0]
def trainvalidate(model,
                  stats,
                  epoch,
                  loader,
                  optimizer,
                  validation,
                  bp_var='objective',
                  metric_print_interval=5,
                  visualize_interval=100,
                  visdom_env_root='trainvalidate',
                  **kwargs):
    """Run one pass over `loader`, either training or validating `model`.

    For every batch, the model outputs are merged with the model inputs into
    a single dict. That dict is passed to the stats logger, periodically
    printed, and periodically handed to `model.visualize`. In training mode
    `preds[bp_var]` is then backpropagated and the optimizer is stepped.

    Args:
        model: Network called as `model(**net_input)`. It must return a
            dict-like object (it is `.update()`-d and indexed by `bp_var`)
            and provide `eval()`, `train()` and `visualize(...)`.
            Presumably a `torch.nn.Module` -- not enforced here.
        stats: Stats logger exposing `update(...)`, `print(...)` and an
            `it` mapping from stat-set name ('train'/'val') to the current
            iteration index.
        epoch: Current epoch index; not used by this function.
        loader: Sized iterable of batches (`len(loader)` is taken).
        optimizer: Object with `zero_grad()` and `step()`. Only used when
            `validation` is False.
        validation: If True, put the model in eval mode, run the forward
            pass under `torch.no_grad()`, and skip backward/optimizer step.
        bp_var: Key of the model output that is backpropagated in training.
        metric_print_interval: Print stats every this many iterations. The
            last iteration is always printed. A non-positive value disables
            the periodic print, so only the last iteration is printed.
        visualize_interval: Call `model.visualize` every this many
            iterations. A non-positive value disables visualisation.
        visdom_env_root: Prefix of the environment name passed to
            `model.visualize`, which is `<visdom_env_root>_images_<train|val>`.
        **kwargs: Accepted and ignored (presumably so a shared config dict
            can be splatted in -- confirm against callers).

    Returns:
        None.
    """
    if validation:
        model.eval()
        trainmode = 'val'
    else:
        model.train()
        trainmode = 'train'

    t_start = time.time()

    # clear the visualisations only on the first visualised batch of the epoch
    clear_visualisations = True

    # environment name handed to model.visualize
    visdom_env_imgs = visdom_env_root + "_images_" + trainmode

    n_batches = len(loader)
    for it, batch in enumerate(loader):
        last_iter = it == n_batches - 1

        # build the model kwargs from the batch (original note: this moves
        # tensors to the gpu where possible; get_net_input is defined elsewhere)
        net_input = get_net_input(batch)

        # the forward pass; gradients are only tracked when training
        if validation:
            with torch.no_grad():
                preds = model(**net_input)
        else:
            optimizer.zero_grad()
            preds = model(**net_input)

        # the inputs are merged into preds below; refuse to silently
        # overwrite any model output that shares a key with an input
        clashing = [k for k in net_input if k in preds]
        assert not clashing, \
            "model outputs clash with net input keys: %r" % clashing
        preds.update(net_input)  # merge everything into one big dict

        # update the stats logger
        stats.update(preds, time_start=t_start, stat_set=trainmode)
        assert stats.it[trainmode] == it, "inconsistent stat iteration number!"

        # print textual status update; the > 0 guard avoids a modulo by zero
        # and mirrors the visualize_interval handling below
        print_now = (metric_print_interval > 0
                     and it % metric_print_interval == 0)
        if print_now or last_iter:
            stats.print(stat_set=trainmode, max_it=n_batches)

        # visualize results
        if visualize_interval > 0 and it % visualize_interval == 0:
            model.visualize(visdom_env_imgs, trainmode,
                            preds, stats, clear_env=clear_visualisations)
            clear_visualisations = False

        # optimizer step (training only)
        if not validation:
            loss = preds[bp_var]
            loss.backward()
            optimizer.step()