def train_epoch()

in engine/training_engine.py


    def train_epoch(self, epoch):
        time.sleep(2)  # To prevent possible deadlock during epoch transition
        train_stats = Statistics(metric_names=self.metric_names, is_master_node=self.is_master_node)

        self.model.train()
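        # enable gradient accumulation only after `accum_after_epoch`;
        # until then the optimizer steps on every batch (accum_freq == 1)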
        accum_freq = self.accum_freq if epoch > self.accum_after_epoch else 1
        max_norm = getattr(self.opts, "common.grad_clip", None)

        self.optimizer.zero_grad()

        epoch_start_time = time.time()
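        # time data loading separately from the forward/backward compute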
        batch_load_start = time.time()
        for batch_id, batch in enumerate(self.train_loader):
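            # stop once the global iteration budget is exhausted; (-1, -1) are
            # sentinel values telling the caller no epoch statistics were produced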
            if self.train_iterations > self.max_iterations:
                self.max_iterations_reached = True
                return -1, -1

            batch_load_toc = time.time() - batch_load_start
            input_img, target_label = batch['image'], batch['label']
            # move data to device
            input_img = input_img.to(self.device)
            if isinstance(target_label, Dict):
                for k, v in target_label.items():
                    target_label[k] = v.to(self.device)
            else:
                target_label = target_label.to(self.device)

            batch_size = input_img.shape[0]

            # update the learning rate
            self.optimizer = self.scheduler.update_lr(optimizer=self.optimizer, epoch=epoch,
                                                      curr_iter=self.train_iterations)

            # adjust bn momentum
            if self.adjust_norm_mom is not None:
                self.adjust_norm_mom.adjust_momentum(model=self.model,
                                                     epoch=epoch,
                                                     iteration=self.train_iterations)

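            # run the forward pass and loss in reduced precision when
            # mixed-precision training is enabled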
            with autocast(enabled=self.mixed_precision_training):
                # prediction
                pred_label = self.model(input_img)
                # compute loss
                loss = self.criteria(input_sample=input_img, prediction=pred_label, target=target_label)

                if isinstance(loss, torch.Tensor) and torch.isnan(loss):
                    # fail fast on a NaN loss; a debugger breakpoint here would
                    # hang the other workers in distributed training
                    raise RuntimeError("NaN encountered in the loss")

            # perform the backward pass with gradient accumulation [Optional]
            self.gradient_scalar.scale(loss).backward()

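            # step the optimizer only every accum_freq batches; gradients from
            # the intermediate batches accumulate in .grad between steps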
            if (batch_id + 1) % accum_freq == 0:
                if max_norm is not None:
                    # For gradient clipping, unscale the gradients and then clip them
                    self.gradient_scalar.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=max_norm)

                # optimizer step
                self.gradient_scalar.step(optimizer=self.optimizer)
                # update the scale for next batch
                self.gradient_scalar.update()
                # set grads to zero
                self.optimizer.zero_grad()

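                # keep the EMA copy of the weights in sync after every optimizer step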
                if self.model_ema is not None:
                    self.model_ema.update_parameters(self.model)

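            # compute batch metrics; use_distributed controls whether they are
            # reduced across processes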
            metrics = metric_monitor(pred_label=pred_label, target_label=target_label, loss=loss,
                                     use_distributed=self.use_distributed, metric_names=self.metric_names)

            train_stats.update(metric_vals=metrics, batch_time=batch_load_toc, n=batch_size)

            if batch_id % self.log_freq == 0 and self.is_master_node:
                lr = self.scheduler.retrieve_lr(self.optimizer)
                train_stats.iter_summary(epoch=epoch,
                                         n_processed_samples=self.train_iterations,
                                         total_samples=self.max_iterations,
                                         learning_rate=lr,
                                         elapsed_time=epoch_start_time)

            batch_load_start = time.time()
            self.train_iterations += 1

        avg_loss = train_stats.avg_statistics(metric_name='loss')
        train_stats.epoch_summary(epoch=epoch, stage="training")
        avg_ckpt_metric = train_stats.avg_statistics(metric_name=self.ckpt_metric)
        return avg_loss, avg_ckpt_metric
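
For reference, the mixed-precision + gradient-accumulation pattern above reduces to the
minimal, self-contained sketch below. All names here (train_one_epoch, model, criterion,
loader, and the accum_freq/max_norm defaults) are illustrative, not part of the engine's
API. One deliberate difference: the sketch divides the loss by accum_freq so the
accumulated gradient matches the gradient of the mean loss over the effective batch,
whereas train_epoch() above applies the loss unscaled.

    import torch
    from torch.cuda.amp import GradScaler, autocast

    def train_one_epoch(model, criterion, loader, optimizer, device,
                        accum_freq=4, max_norm=1.0):
        scaler = GradScaler()
        model.train()
        optimizer.zero_grad()
        for batch_id, (images, targets) in enumerate(loader):
            images, targets = images.to(device), targets.to(device)
            with autocast():
                # dividing by accum_freq makes the accumulated gradient equal
                # to the gradient of the mean loss over the effective batch
                loss = criterion(model(images), targets) / accum_freq
            # scale before backward so fp16 gradients do not underflow
            scaler.scale(loss).backward()
            if (batch_id + 1) % accum_freq == 0:
                if max_norm is not None:
                    scaler.unscale_(optimizer)  # clip in true gradient units
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
                scaler.step(optimizer)  # skipped if grads contain inf/NaN
                scaler.update()         # adjust loss scale for the next step
                optimizer.zero_grad()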