in sockeye/training_pt.py [0:0]
def fit(self,
        train_iter: data_io_pt.BaseParallelSampleIter,
        validation_iter: data_io_pt.BaseParallelSampleIter,
        checkpoint_decoder: Optional[checkpoint_decoder_pt.CheckpointDecoder] = None):
    logger.info("Early stopping by optimizing '%s'", self.config.early_stopping_metric)

    if utils.is_primary_worker() and self.config.early_stopping_metric in C.METRICS_REQUIRING_DECODER:
        utils.check_condition(checkpoint_decoder is not None,
                              "%s requires CheckpointDecoder" % self.config.early_stopping_metric)
    resume_training = os.path.exists(self.training_state_dirname)
    if resume_training:
        logger.info("Found partial training in '%s'. Resuming from saved state.", self.training_state_dirname)
        self._load_training_state(train_iter)
    else:
        self.state = TrainState(self.config.early_stopping_metric)
        if utils.is_primary_worker():
            self.sockeye_model.save_config(self.config.output_dir)
            self.sockeye_model.save_version(self.config.output_dir)
            self.sockeye_model.save_parameters(self.current_params_fname)
        logger.info("Training started.")

    tic = time.time()
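
    # A max_checkpoints limit is implemented by converting it into an absolute max_updates value relative
    # to the current update count, e.g. updates=1000, max_checkpoints=5, checkpoint_interval=500 gives
    # max_updates = 1000 + 5 * 500 = 3500.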
    if self.config.max_checkpoints is not None:
        self.config.max_updates = self.state.updates + self.config.max_checkpoints * self.config.checkpoint_interval
        logger.info("Resetting max_updates to %d + %d * %d = %d in order to implement stopping "
                    "after (an additional) %d checkpoints.",
                    self.state.updates,
                    self.config.max_checkpoints,
                    self.config.checkpoint_interval,
                    self.config.max_updates,
                    self.config.max_checkpoints)
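
    # Main training loop: hard limits on epochs, updates, and samples are checked before each batch,
    # and a final checkpoint is written on exit if the most recent gradient step is not yet covered
    # by a checkpoint.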
    checkpoint_up_to_date = False
    while True:
        if self.config.max_epochs is not None and self.state.epoch == self.config.max_epochs:
            logger.info("Maximum # of epochs (%s) reached.", self.config.max_epochs)
            if not checkpoint_up_to_date:
                time_cost = time.time() - tic
                self._create_checkpoint(checkpoint_decoder, time_cost, train_iter, validation_iter)
            break

        if self.config.max_updates is not None and self.state.updates == self.config.max_updates:
            logger.info("Maximum # of updates (%s) reached.", self.config.max_updates)
            if not checkpoint_up_to_date:
                time_cost = time.time() - tic
                self._create_checkpoint(checkpoint_decoder, time_cost, train_iter, validation_iter)
            break

        if self.config.max_samples is not None and self.state.samples >= self.config.max_samples:
            logger.info("Maximum # of samples (%s) reached.", self.config.max_samples)
            if not checkpoint_up_to_date:
                time_cost = time.time() - tic
                self._create_checkpoint(checkpoint_decoder, time_cost, train_iter, validation_iter)
            break
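
        # Train on the next batch; _step() reports whether a parameter update was applied, in which
        # case the latest checkpoint no longer reflects the current parameters.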
        did_grad_step = self._step(batch=train_iter.next())
        checkpoint_up_to_date = checkpoint_up_to_date and not did_grad_step

        if not train_iter.iter_next():
            self.state.epoch += 1
            train_iter.reset()
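
        # Write a checkpoint every checkpoint_interval * update_interval batches; stopping on elapsed
        # time, convergence, or divergence is only evaluated immediately after a checkpoint.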
        if self.state.updates > 0 and self.state.batches % (
                self.config.checkpoint_interval * self.config.update_interval) == 0:
            time_cost = time.time() - tic
            self._create_checkpoint(checkpoint_decoder, time_cost, train_iter, validation_iter)
            checkpoint_up_to_date = True

            if self.config.max_seconds is not None and self.state.time_elapsed >= self.config.max_seconds:
                logger.info("Maximum # of seconds (%s) reached. Training ran for %d seconds.",
                            self.config.max_seconds, self.state.time_elapsed)
                break

            if self.state.converged or self.state.diverged:
                break

            tic = time.time()
logger.info("Training finished%s. Best checkpoint: %d. Best validation %s: %.6f",
", can be continued later" if not self.state.converged else "",
self.state.best_checkpoint, self.state.early_stopping_metric, self.state.best_metric)
# Always keep the training state to allow continuing training with
# different stopping criteria
if utils.is_primary_worker():
self._cleanup(keep_training_state=True)
return self.state