in dora/lightning.py [0:0]
def get_trainer(*args, add_dora_logger=True, no_unfinished_epochs=True, **kwargs):
"""Return a PL trainer, adding the necessary glue code to make everything works.
The arguments are exactly the same as for `pytorch_lightning.trainer.Trainer`,
with a few extras documented after.
..note:: You should not pass `gpus=` or `num_nodes=` arguments as those will be filled by Dora.
Args:
add_dora_logger (bool): if True, adds a Dora Logger to automatically
forward the metrics (those logged with per_epoch=True), otherwise
pushing metrics will be up to you.
no_unfinished_epochs (bool): if True, deactivates SLURM signal handling
by PL, which can result in half finished epoch with each interruption.
It is recommended to instead dump a checkpoint every epoch and resume
from that one so that training is reliable.
"""
    if not is_xp():
        raise RuntimeError("This can only be called from inside a Dora XP.")
    # Convert everything to kwargs, prepending a `None` dummy for the missing `self`,
    # so that individual Trainer arguments can be inspected and overridden below.
    init = Trainer.__init__
    while hasattr(init, '__wrapped__'):
        init = init.__wrapped__
    kwargs = inspect.getcallargs(init, *([None] + list(args)), **kwargs)
    del kwargs['self']
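    # Describe the distributed setup chosen by Dora to Lightning through a
    # dedicated environment plugin; DDP glue is only added when more than one
    # process is involved.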
plugins = kwargs.pop("plugins") or []
env = DoraEnvironment()
gpus = min(torch.cuda.device_count(), env.world_size())
if env.world_size() > 1:
plugins += [env, 'dora_ddp']
kwargs['plugins'] = plugins
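    # RestoreDoraHistory keeps Dora's metric history in sync with Lightning
    # checkpoints, so the history stays consistent when a run is resumed.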
callbacks = kwargs.pop("callbacks", [])
callbacks.append(RestoreDoraHistory())
kwargs['callbacks'] = callbacks
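    # The number of GPUs and nodes is owned by Dora: reject user-provided values
    # and fill them from the Dora environment instead.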
    if kwargs['gpus'] is not None:
        raise RuntimeError("You cannot specify the number of GPUs, as this is provided by Dora.")
    if kwargs['num_nodes'] != 1:
        raise RuntimeError("You cannot specify the number of nodes, as this is provided by Dora.")
    kwargs['gpus'] = gpus
    kwargs['num_nodes'] = env.spec.num_nodes
    kwargs['default_root_dir'] = get_xp().folder
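    # Optionally forward per-epoch metrics to Dora: keep (or re-create) the usual
    # PL logger and register an extra DoraHistoryLogger next to it.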
    if add_dora_logger:
        logger = kwargs['logger']
        if logger is True:
            version = os.environ.get('PL_EXP_VERSION')
            if version is None:
                version = os.environ.get('SLURM_JOB_ID')
            # Create default logger as in PL logger_connector.py
            logger = TensorBoardLogger(
                save_dir=get_xp().folder, version=version, name='lightning_logs')
        if not isinstance(logger, tp.Iterable):
            logger = [logger]
        dora_logger = DoraHistoryLogger()
        kwargs['callbacks'].append(_ArmDoraLogger(dora_logger))
        logger.append(dora_logger)
        kwargs['logger'] = logger
    trainer = Trainer(**kwargs)
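    # Neutralize PL's own SLURM signal handling if requested (see the docstring),
    # so that an interruption cannot leave a half-finished epoch behind.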
    if no_unfinished_epochs:
        trainer.slurm_connector = _DummySLURMConnector()
    return trainer
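

# Illustrative usage sketch (not part of the upstream module): `get_trainer` is
# meant to be called from inside a Dora XP, e.g. within an entry point decorated
# with Dora's `hydra_main` or `argparse_main`. `MyModule`, `MyDataModule` and the
# config fields below are hypothetical placeholders:
#
#     from dora import hydra_main
#
#     @hydra_main(config_path="conf", config_name="config")
#     def main(cfg):
#         model = MyModule(cfg)
#         data = MyDataModule(cfg)
#         # `gpus=` / `num_nodes=` must not be passed; Dora fills them in.
#         trainer = get_trainer(max_epochs=cfg.epochs, precision=16)
#         trainer.fit(model, datamodule=data)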