in torchrecipes/core/base_train_app.py [0:0]
def train(self) -> TrainOutput:
    """Fit the module on the datamodule and log the outcome of the run.

    Builds the trainer via ``self._get_trainer()``, runs
    ``trainer.fit(self.module, datamodule=self.datamodule)`` and then reports
    the trainer's global rank, world size, wall-clock run time (truncated to
    whole seconds) and run status through ``log_run``. If fitting raises, the
    error message and the exception's own traceback are logged with a FAILED
    status before the exception is re-raised.

    Returns:
        TrainOutput holding ``self.log_dir`` as the tensorboard log directory
        and the checkpoint callback's ``best_model_path`` (``None`` when the
        callback does not expose that attribute).

    Raises:
        Exception: whatever ``trainer.fit`` raised, re-raised after logging.
    """
    trainer, log_params = self._get_trainer()
    start_time = time.monotonic()
    got_exception = None
    try:
        trainer.fit(self.module, datamodule=self.datamodule)
    except Exception as ex:
        # Defer re-raising until the run status has been logged below.
        got_exception = ex
    # log trainer status to Scuba and Hive
    total_run_time = int(time.monotonic() - start_time)
    log_params["global_rank"] = trainer.global_rank
    log_params["world_size"] = trainer.world_size
    log_params["total_run_time"] = total_run_time
    if got_exception is None:
        log_params["run_status"] = JobStatus.COMPLETED.value
        log_run(**log_params)
    else:
        log_params["error_message"] = str(got_exception)
        # format_stack() would capture only the call stack of *this* frame, not
        # where training failed. format_exc() is also unusable here because the
        # except block has already exited. Format the exception's own traceback.
        log_params["stacktrace"] = traceback.format_exception(
            type(got_exception), got_exception, got_exception.__traceback__
        )
        log_params["run_status"] = JobStatus.FAILED.value
        log_run(**log_params)
        raise got_exception
    # getattr with a default also covers a callback that is None or lacks the attr.
    best_model_path = getattr(self._checkpoint_callback, "best_model_path", None)
    return TrainOutput(
        tensorboard_log_dir=self.log_dir, best_model_path=best_model_path
    )