in mmf/trainers/core/training_loop.py [0:0]
def run_training_epoch(self) -> None:
    should_break = False
    while self.num_updates < self.max_updates and not should_break:
        self.current_epoch += 1
        registry.register("current_epoch", self.current_epoch)

        # Seed the sampler in case it is distributed
        self.dataset_loader.seed_sampler("train", self.current_epoch)

        # For iterable datasets we cannot determine the length of the dataset
        # properly. For those cases we set num_remaining_batches to
        # (number of updates remaining x update_frequency).
        num_remaining_batches = (
            (
                (self.max_updates - self.num_updates)
                * self.training_config.update_frequency
            )
            if isinstance(
                self.train_loader.current_dataset, torch.utils.data.IterableDataset
            )
            else len(self.train_loader)
        )
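        # Worked example (illustrative values, not from any real config): with
        # max_updates=22000, num_updates=21000 and update_frequency=2, an
        # IterableDataset leaves (22000 - 21000) * 2 = 2000 remaining batches.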

        should_start_update = True
        for idx, batch in enumerate(self.train_loader):
            if should_start_update:
                combined_report = None
                self._start_update()
                num_batches_for_this_update = min(
                    self.training_config.update_frequency, num_remaining_batches
                )
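                # Illustrative only: with update_frequency=8 but just 3 batches
                # left in the epoch, this update accumulates over 3 batches.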
                should_start_update = False

            self.current_iteration += 1
            # batch execution starts here
            self.on_batch_start()
            self.profile("Batch load time")

            report = self.run_training_batch(batch, num_batches_for_this_update)
            report = report.detach()

            # accumulate necessary params (including loss) for metric calculation
            if combined_report is None:
                combined_report = report
            else:
                combined_report.accumulate_tensor_fields_and_loss(
                    report, self.metrics.required_params
                )
                combined_report.batch_size += report.batch_size

            # batch execution ends here
            self.on_batch_end(report=combined_report, meter=self.meter)

            # check whether an update has finished or this is the last one;
            # if neither, keep accumulating batches
            if (
                (idx + 1) % self.training_config.update_frequency
                and num_remaining_batches != num_batches_for_this_update
            ):
                continue
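
            # Reaching this point means either update_frequency batches have
            # been accumulated or the epoch's final (possibly partial) group
            # is complete, so the optimizer step runs below.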
            self._finish_update()
            should_start_update = True

            should_log = False
            if self.num_updates % self.logistics_callback.log_interval == 0:
                should_log = True
                # Calculate metrics every log interval for debugging
                if self.training_config.evaluate_metrics:
                    combined_report.metrics = self.metrics(
                        combined_report, combined_report
                    )
                self.meter.update_from_report(combined_report)

            self.on_update_end(
                report=combined_report, meter=self.meter, should_log=should_log
            )

            num_remaining_batches -= num_batches_for_this_update

            # Check if training should be stopped
            should_break = False

            if self.num_updates % self.training_config.evaluation_interval == 0:
                # Validation begin callbacks
                self.on_validation_start()

                logger.info("Evaluation time. Running on full validation set...")
                # Validation and Early stopping
                # Create a new meter for this case
                report, meter = self.evaluation_loop("val")

                # Validation end callbacks
                stop = self.early_stop_callback.on_validation_end(
                    report=report, meter=meter
                )
                self.on_validation_end(report=report, meter=meter)

                gc.collect()

                if "cuda" in str(self.device):
                    torch.cuda.empty_cache()

                if stop is True:
                    logger.info("Early stopping activated")
                    should_break = True

            if self.num_updates >= self.max_updates:
                should_break = True

            if should_break:
                break
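
The loop above groups batches into accumulation windows of size update_frequency (plus one possibly partial window at the end of the epoch) and runs one optimizer update per window. The sketch below is illustrative only, assuming a made-up helper name (sketch_update_schedule) and plain integers in place of MMF's trainer state; it is not part of the MMF API.

def sketch_update_schedule(
    num_batches: int, max_updates: int, update_frequency: int
) -> list:
    """Return the window sizes that would each trigger one optimizer update."""
    groups = []
    num_updates = 0
    # Never schedule more batches than the remaining updates can consume
    # (this mirrors the iterable-dataset cap computed above).
    remaining = min(num_batches, (max_updates - num_updates) * update_frequency)
    while remaining > 0 and num_updates < max_updates:
        group = min(update_frequency, remaining)  # last window may be partial
        groups.append(group)
        remaining -= group
        num_updates += 1
    return groups


# Example: 10 batches with update_frequency=4 -> windows of [4, 4, 2],
# i.e. two full accumulation windows and one partial one at epoch end.
assert sketch_update_schedule(10, max_updates=100, update_frequency=4) == [4, 4, 2]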