def run()

in banding_removal/fastmri/training_loop_mixin.py [0:0]


    def run(self, epoch):
        """Run one training epoch over ``self.train_loader``.

        For every batch this:
          1. calls ``self.start_of_batch_hook(progress, logging_epoch)``;
          2. takes one optimizer step. The closure computes the loss via
             ``self.training_loss`` and ``self.additional_training_loss_terms``,
             calls ``self.midbatch_hook``, zeroes gradients and calls
             ``self.backwards(loss)``;
          3. converts every loss term to a Python float and updates its running
             average (a cumulative mean for the first 50 batches, then an
             exponential moving average with decay 0.99);
          4. appends ``batch['fname']`` to ``self.runinfo['train_fnames']`` and
             passes the instantaneous and average losses to
             ``self.training_loss_hook``.

        On every ``args.log_interval``-th batch and on the last batch it logs:
          - the percentage of batches done;
          - the current loss;
          - an estimate of the full-epoch duration;
          - the most recently sampled CUDA memory usage.

        Args:
            epoch: the current epoch number. It is logged as "Epoch" and added to
                ``batch_idx / nbatches`` to form the fractional ``progress``
                value passed to the hooks.

        The loop can end before the loader is exhausted in two ways:
          - on a logged batch once the percent of batches done reaches
            ``args.break_early`` (when that is not None);
          - after the first batch when ``args.debug_epoch`` is truthy.

        Returns:
            None.
        """
        args = self.args
        self.model.train()
        nbatches = self.nbatches
        # Wall-clock time of the previous log line; used for the duration estimate.
        interval = timer()
        # Percent of batches done as of the previous log line.
        percent_done = 0
        # CUDA memory in GB (bytes / 1e9), updated inside the closure on logged batches.
        memory_gb = 0.0
        # loss name -> running average of that loss (Python float).
        avg_losses = {}

        if self.args.nan_detection:
            # Enables autograd anomaly detection globally.
            # NOTE(review): it is never switched off in this method.
            autograd.set_detect_anomaly(True)

        for batch_idx, batch in enumerate(self.train_loader):
            self.batch_idx = batch_idx
            # Fractional epoch position, e.g. 3.25 = a quarter of the way into epoch 3.
            progress = epoch + batch_idx/nbatches

            # Despite the name this is a per-batch flag. It is True for batches
            # that get a log line: every log_interval batches plus the final batch.
            logging_epoch = (batch_idx % args.log_interval == 0
                             or batch_idx == (nbatches-1))

            self.start_of_batch_hook(progress, logging_epoch)

            if batch_idx == 0:
                logging.info("Starting batch 0")
                sys.stdout.flush()

            def batch_closure(subbatch):
                """Compute the loss for ``subbatch``, backpropagate it, and
                return ``(loss, loss_dict)``.

                Gradients are zeroed here, right before the backward pass,
                because the optimizer invokes this closure itself.
                """
                nonlocal memory_gb

                # training_loss may return a tensor, a dict of tensors, or a
                # (tensor-or-dict, prediction, target) tuple.
                result = self.training_loss(subbatch)

                if isinstance(result, tuple):
                    result, prediction, target = result
                else:
                    prediction = None # For backwards compatibility
                    target = None

                if isinstance(result, torch.Tensor):
                    # By default self.training_loss() returns a single tensor
                    loss_dict = {'train_loss': result}
                else:
                    # Perceptual loss will return a dict of losses where the main
                    # loss is 'train_loss'. This is for easily logging the parts
                    # composing the loss (eg, perceptual loss + l1)
                    loss_dict = result

                # Only the first element of the returned 4-tuple is used here.
                loss_dict, _, _, _ = self.additional_training_loss_terms(
                                        loss_dict, subbatch, prediction, target)

                # 'train_loss' is the term that gets backpropagated.
                loss = loss_dict['train_loss']

                # Memory usage is at its maximum right before backprop
                # (sampled only on logged batches; bytes / 1e9 -> GB)
                if logging_epoch and self.args.cuda:
                    memory_gb = torch.cuda.memory_allocated()/1000000000

                self.midbatch_hook(progress, logging_epoch)

                self.optimizer.zero_grad()
                # Presumably runs loss.backward() plus any extra handling; see self.backwards.
                self.backwards(loss)
                return loss, loss_dict

            if hasattr(self.optimizer, 'batch_step'):
                # Custom optimizers receive the whole batch and the per-subbatch
                # closure. NOTE(review): presumably they split the batch into
                # sub-batches themselves; confirm against the optimizer.
                loss, loss_dict = self.optimizer.batch_step(batch, batch_closure=batch_closure)
            else:
                # Relies on optimizer.step(closure) returning the closure's own
                # return value, i.e. the (loss, loss_dict) tuple.
                closure = lambda: batch_closure(batch)
                loss, loss_dict = self.optimizer.step(closure=closure)

            if args.debug:
                self.check_for_nan(loss)

            # Running average of all losses returned

            # Each entry of loss_dict is replaced in place by its Python float
            # value; `loss` itself remains a tensor (it is .item()'d when logged).
            # NOTE(review): for batch_idx > 0 this assumes every loss name was
            # already seen at batch 0. A name that first appears later raises KeyError.
            for name in loss_dict:
                loss_gpu = loss_dict[name]
                loss_cpu = loss_gpu.cpu().item()
                loss_dict[name] = loss_cpu
                if batch_idx == 0:
                    avg_losses[name] = loss_cpu
                elif batch_idx < 50:
                    # Exact cumulative mean over batches 0..batch_idx.
                    avg_losses[name] = (batch_idx*avg_losses[name] + loss_cpu)/(batch_idx+1)
                else:
                    # Exponential moving average with decay 0.99.
                    avg_losses[name] = 0.99*avg_losses[name] + 0.01*loss_cpu

            # Flatten into {'instantaneous_<name>': float, 'average_<name>': float}.
            losses = {}
            for name in loss_dict:
                losses['instantaneous_' + name] = loss_dict[name]
                losses['average_' + name] = avg_losses[name]

            self.runinfo['train_fnames'].append(batch['fname'])
            self.training_loss_hook(progress, losses, logging_epoch)

            del losses

            if logging_epoch:
                mid = timer()
                # Percent of batches finished before this one.
                new_percent_done = 100. * batch_idx / nbatches
                percent_change = new_percent_done - percent_done
                percent_done = new_percent_done
                if percent_done > 0:
                    # Time since the previous log line divided by the fraction of
                    # the epoch covered since then. This estimates the full-epoch
                    # duration in seconds at the current rate.
                    inst_estimate =  math.ceil((mid - interval)/(percent_change/100))
                    inst_estimate = str(datetime.timedelta(seconds=inst_estimate))
                else:
                    # First logged batch (batch 0): no elapsed progress to extrapolate from.
                    inst_estimate = "unknown"

                logging.info('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, inst: {} Mem: {:2.1f}gb'.format(
                    epoch, batch_idx, nbatches,
                    100. * batch_idx / nbatches, loss.item(), inst_estimate,
                    memory_gb))
                interval = mid

                # Only checked on logged batches, so the stopping granularity
                # follows args.log_interval.
                if self.args.break_early is not None and percent_done >= self.args.break_early:
                    break

            # Debug mode: stop after a single batch.
            if self.args.debug_epoch:
                break