def run_training_epoch()

in mmf/trainers/core/training_loop.py


    def run_training_epoch(self) -> None:
        should_break = False
        while self.num_updates < self.max_updates and not should_break:
            self.current_epoch += 1
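            # Publish the new epoch number in the global registry so other
            # components (callbacks, datasets) can look it up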
            registry.register("current_epoch", self.current_epoch)

            # Seed the sampler in case it is distributed
            self.dataset_loader.seed_sampler("train", self.current_epoch)

            # For iterable datasets we cannot reliably determine the dataset
            # length. In that case we set num_remaining_batches to
            # (number of remaining updates x update_frequency)
            num_remaining_batches = (
                (
                    (self.max_updates - self.num_updates)
                    * self.training_config.update_frequency
                )
                if isinstance(
                    self.train_loader.current_dataset, torch.utils.data.IterableDataset
                )
                else len(self.train_loader)
            )

            should_start_update = True
            for idx, batch in enumerate(self.train_loader):
                if should_start_update:
                    combined_report = None
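                    # Open a new accumulation window; _start_update() handles
                    # the per-update setup (e.g. zeroing gradients)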
                    self._start_update()
                    num_batches_for_this_update = min(
                        self.training_config.update_frequency, num_remaining_batches
                    )
                    should_start_update = False

                self.current_iteration += 1
                # batch execution starts here
                self.on_batch_start()
                self.profile("Batch load time")

                report = self.run_training_batch(batch, num_batches_for_this_update)
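                # Detach so the retained report does not keep the autograd
                # graph alive across batches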
                report = report.detach()

                # accumulate necessary params (including loss) for metric calculation
                if combined_report is None:
                    combined_report = report
                else:
                    combined_report.accumulate_tensor_fields_and_loss(
                        report, self.metrics.required_params
                    )
                    combined_report.batch_size += report.batch_size

                # batch execution ends here
                self.on_batch_end(report=combined_report, meter=self.meter)

                # Keep accumulating until update_frequency batches have been
                # processed, unless this is the final (possibly shorter) update
                if (
                    (idx + 1) % self.training_config.update_frequency
                    and num_remaining_batches != num_batches_for_this_update
                ):
                    continue

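                # Window complete: _finish_update() applies the optimizer step
                # and advances num_updates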
                self._finish_update()
                should_start_update = True

                should_log = False
                if self.num_updates % self.logistics_callback.log_interval == 0:
                    should_log = True
                    # Calculate metrics every log interval for debugging
                    if self.training_config.evaluate_metrics:
                        combined_report.metrics = self.metrics(
                            combined_report, combined_report
                        )
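                    # Fold this update's accumulated report into the running meter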
                    self.meter.update_from_report(combined_report)

                self.on_update_end(
                    report=combined_report, meter=self.meter, should_log=should_log
                )

                num_remaining_batches -= num_batches_for_this_update

                # Check if training should be stopped
                should_break = False

                if self.num_updates % self.training_config.evaluation_interval == 0:
                    # Validation begin callbacks
                    self.on_validation_start()

                    logger.info("Evaluation time. Running on full validation set...")
                    # Validation and Early stopping
                    # Create a new meter for this case
                    report, meter = self.evaluation_loop("val")

                    # Validation end callbacks
                    stop = self.early_stop_callback.on_validation_end(
                        report=report, meter=meter
                    )
                    self.on_validation_end(report=report, meter=meter)

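                    # Reclaim host memory and release cached CUDA blocks
                    # after the validation pass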
                    gc.collect()

                    if "cuda" in str(self.device):
                        torch.cuda.empty_cache()

                    if stop is True:
                        logger.info("Early stopping activated")
                        should_break = True

                if self.num_updates >= self.max_updates:
                    should_break = True

                if should_break:
                    break
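

The loop implements gradient accumulation: update_frequency batches are
processed per optimizer step, and num_batches_for_this_update is passed down to
run_training_batch, presumably so the loss can be scaled to make the
accumulated gradient match the mean over the window. Below is a minimal,
self-contained sketch of the same control flow in plain PyTorch. It is not
MMF's API; the model, loader, and constants are illustrative stand-ins, and the
explicit loss division stands in for whatever scaling run_training_batch
applies internally.

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    model = nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()

    update_frequency = 4   # batches accumulated per optimizer step
    max_updates = 5        # stop after this many optimizer steps
    num_updates = 0

    # Synthetic stand-in for self.train_loader (32 batches of (x, y))
    loader = [(torch.randn(16, 8), torch.randn(16, 1)) for _ in range(32)]

    num_remaining_batches = len(loader)
    should_start_update = True

    for idx, (x, y) in enumerate(loader):
        if should_start_update:
            optimizer.zero_grad()  # mirrors self._start_update()
            num_batches_for_this_update = min(
                update_frequency, num_remaining_batches
            )
            should_start_update = False

        # Scale the loss so the accumulated gradient is the mean over the window
        loss = loss_fn(model(x), y) / num_batches_for_this_update
        loss.backward()

        # Same boundary test as above: keep accumulating mid-window, except in
        # the final (possibly shorter) window of the epoch
        if (
            (idx + 1) % update_frequency
            and num_remaining_batches != num_batches_for_this_update
        ):
            continue

        optimizer.step()  # mirrors self._finish_update()
        num_updates += 1
        num_remaining_batches -= num_batches_for_this_update
        should_start_update = True

        if num_updates >= max_updates:
            break

    print(f"stopped after {num_updates} updates ({idx + 1} batches)")

With 32 batches and update_frequency = 4, the sketch steps the optimizer every
fourth batch and stops after max_updates = 5 steps, i.e. 20 batches.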