in banding_removal/fastmri/training_loop_mixin.py [0:0]
def run(self, epoch):
    args = self.args
    self.model.train()
    nbatches = self.nbatches
    interval = timer()
    percent_done = 0
    memory_gb = 0.0
    avg_losses = {}

    if self.args.nan_detection:
        autograd.set_detect_anomaly(True)

    for batch_idx, batch in enumerate(self.train_loader):
        self.batch_idx = batch_idx
        progress = epoch + batch_idx/nbatches
        logging_epoch = (batch_idx % args.log_interval == 0
                         or batch_idx == (nbatches-1))
        self.start_of_batch_hook(progress, logging_epoch)

        if batch_idx == 0:
            logging.info("Starting batch 0")
            sys.stdout.flush()
        def batch_closure(subbatch):
            nonlocal memory_gb
            result = self.training_loss(subbatch)
            if isinstance(result, tuple):
                result, prediction, target = result
            else:
                prediction = None  # For backwards compatibility
                target = None

            if isinstance(result, torch.Tensor):
                # By default self.training_loss() returns a single tensor
                loss_dict = {'train_loss': result}
            else:
                # Perceptual loss returns a dict of losses where the main
                # loss is 'train_loss'. This makes it easy to log the parts
                # composing the loss (e.g. perceptual loss + l1)
                loss_dict = result

            loss_dict, _, _, _ = self.additional_training_loss_terms(
                loss_dict, subbatch, prediction, target)
            loss = loss_dict['train_loss']

            # Memory usage is at its maximum right before backprop
            if logging_epoch and self.args.cuda:
                memory_gb = torch.cuda.memory_allocated()/1000000000

            self.midbatch_hook(progress, logging_epoch)
            self.optimizer.zero_grad()
            self.backwards(loss)
            return loss, loss_dict
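
        # Optimizers that expose a `batch_step` method receive the whole batch
        # together with the closure; they appear to handle sub-batching themselves,
        # invoking `batch_closure` once per sub-batch (hence the `subbatch` argument
        # above). Plain torch.optim optimizers fall back to the standard
        # single-closure step(closure=...) call.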
        if hasattr(self.optimizer, 'batch_step'):
            loss, loss_dict = self.optimizer.batch_step(batch, batch_closure=batch_closure)
        else:
            closure = lambda: batch_closure(batch)
            loss, loss_dict = self.optimizer.step(closure=closure)

        if args.debug:
            self.check_for_nan(loss)

        # Running average of all losses returned
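        # For the first 50 batches the average is a plain running mean (so the
        # estimate stabilises quickly); after that it switches to an exponential
        # moving average with decay 0.99.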
        for name in loss_dict:
            loss_gpu = loss_dict[name]
            loss_cpu = loss_gpu.cpu().item()
            loss_dict[name] = loss_cpu
            if batch_idx == 0:
                avg_losses[name] = loss_cpu
            elif batch_idx < 50:
                avg_losses[name] = (batch_idx*avg_losses[name] + loss_cpu)/(batch_idx+1)
            else:
                avg_losses[name] = 0.99*avg_losses[name] + 0.01*loss_cpu
        losses = {}
        for name in loss_dict:
            losses['instantaneous_' + name] = loss_dict[name]
            losses['average_' + name] = avg_losses[name]

        self.runinfo['train_fnames'].append(batch['fname'])
        self.training_loss_hook(progress, losses, logging_epoch)
        del losses
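
        # (mid - interval) below is the wall-clock time spent since the last logging
        # step; dividing it by the fraction of the epoch covered in that span
        # extrapolates to an estimated duration for a full epoch.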
        if logging_epoch:
            mid = timer()
            new_percent_done = 100. * batch_idx / nbatches
            percent_change = new_percent_done - percent_done
            percent_done = new_percent_done
            if percent_done > 0:
                inst_estimate = math.ceil((mid - interval)/(percent_change/100))
                inst_estimate = str(datetime.timedelta(seconds=inst_estimate))
            else:
                inst_estimate = "unknown"
            logging.info('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, inst: {} Mem: {:2.1f}gb'.format(
                epoch, batch_idx, nbatches,
                100. * batch_idx / nbatches, loss.item(), inst_estimate,
                memory_gb))
            interval = mid

        if self.args.break_early is not None and percent_done >= self.args.break_early:
            break
        if self.args.debug_epoch:
            break
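

# A minimal sketch of the `batch_step` interface the loop above dispatches on.
# None of the names below appear in this file; `SubBatchOptimizerSketch` and
# `split_batch` are hypothetical stand-ins that only illustrate the assumed
# contract: `batch_step` receives the full batch, calls `batch_closure` once per
# sub-batch (the closure zeroes gradients, computes the loss and runs backward),
# and returns the last (loss, loss_dict) pair, mirroring what
# optimizer.step(closure) returns in the fallback branch.
class SubBatchOptimizerSketch:
    def __init__(self, optimizer, split_batch):
        self.optimizer = optimizer      # a plain torch.optim optimizer
        self.split_batch = split_batch  # callable: batch -> iterable of sub-batches

    def zero_grad(self):
        # The training closure calls self.optimizer.zero_grad(), so forward it.
        self.optimizer.zero_grad()

    def batch_step(self, batch, batch_closure):
        loss, loss_dict = None, None
        for subbatch in self.split_batch(batch):
            loss, loss_dict = batch_closure(subbatch)  # zero_grad + loss + backward
            self.optimizer.step()                      # update after each sub-batch
        return loss, loss_dict

# Usage would look like wrapping an existing optimizer before handing it to the
# trainer, e.g. SubBatchOptimizerSketch(torch.optim.Adam(model.parameters()), split_fn);
# the real optimizers in this codebase may implement batch_step quite differently.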