in experiment.py [0:0]
def trainvalidate(model,
                  stats,
                  epoch,
                  loader,
                  optimizer,
                  validation,
                  bp_var='objective',
                  metric_print_interval=5,
                  visualize_interval=100,
                  visdom_env_root='trainvalidate',
                  **kwargs):
    """Run one pass of training or validation over ``loader``.

    For every batch the model is run forward and its (dict-like) outputs are
    merged with the network inputs. The stats logger is updated, and progress
    is printed / visualized at the requested intervals. In training mode,
    ``preds[bp_var]`` is back-propagated and the optimizer is stepped.

    Args:
        model: Module called as ``model(**net_input)``. It must return a
            dict-like object and provide
            ``visualize(env, trainmode, preds, stats, clear_env=...)``.
        stats: Logger exposing ``update(preds, time_start=, stat_set=)``,
            ``print(stat_set=, max_it=)`` and a per-set counter
            ``stats.it[stat_set]``. The loop asserts that this counter
            equals the loader index after every update.
        epoch: Not used inside this function; kept for interface
            compatibility.
        loader: Iterable of batches that supports ``len()``.
        optimizer: Optimizer used only in training mode (it is never touched
            when ``validation`` is true).
        validation: If true, put the model in eval mode, run the forward
            pass under ``torch.no_grad()``, and skip backward/step.
        bp_var: Key of the ``preds`` entry that is back-propagated.
        metric_print_interval: Print stats every this many iterations, and
            always on the last iteration. A value <= 0 disables the periodic
            prints; only the last iteration is printed.
        visualize_interval: Call ``model.visualize`` every this many
            iterations; a value <= 0 disables visualization.
        visdom_env_root: Prefix of the visdom environment name. Images go to
            ``<visdom_env_root>_images_<train|val>``.
        **kwargs: Ignored; accepted for interface compatibility.

    Raises:
        AssertionError: If a model output key collides with a network input
            key, or the stats iteration counter is out of sync with the
            loader index (only while assertions are enabled).
        KeyError: In training mode, if ``bp_var`` is missing from the
            model outputs.
    """
    if validation:
        model.eval()
        trainmode = 'val'
    else:
        model.train()
        trainmode = 'train'

    t_start = time.time()
    # Only the first visualization call in this pass gets clear_env=True;
    # every later call in the same pass gets clear_env=False.
    clear_visualisations = True
    visdom_env_imgs = visdom_env_root + "_images_" + trainmode
    n_batches = len(loader)

    for it, batch in enumerate(loader):
        last_iter = it == n_batches - 1

        # The original note says this moves the batch to the GPU where
        # possible. NOTE(review): confirm against get_net_input's definition.
        net_input = get_net_input(batch)

        # Forward pass; gradients are tracked only in training mode.
        if validation:
            with torch.no_grad():
                preds = model(**net_input)
        else:
            optimizer.zero_grad()
            preds = model(**net_input)

        # Merge the inputs into the outputs dict, but never silently
        # overwrite a model output with an input of the same name.
        clashing = [k for k in net_input if k in preds]
        assert not clashing, (
            "model outputs collide with network inputs: %s" % clashing)
        preds.update(net_input)

        stats.update(preds, time_start=t_start, stat_set=trainmode)
        assert stats.it[trainmode] == it, "inconsistent stat iteration number!"

        # Periodic textual status. A non-positive interval disables the
        # periodic prints (instead of raising ZeroDivisionError) but still
        # prints on the final iteration.
        periodic_print = (metric_print_interval > 0
                          and it % metric_print_interval == 0)
        if periodic_print or last_iter:
            stats.print(stat_set=trainmode, max_it=n_batches)

        if visualize_interval > 0 and it % visualize_interval == 0:
            model.visualize(visdom_env_imgs, trainmode,
                            preds, stats, clear_env=clear_visualisations)
            clear_visualisations = False

        # Backward pass and parameter update (training only).
        if not validation:
            loss = preds[bp_var]
            loss.backward()
            optimizer.step()