in dora/shep.py [0:0]
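
The method below reads only a handful of fields from `slurm_config`. As a point of reference, here is a minimal sketch of that shape, assuming a dataclass roughly along these lines; the real `SlurmConfig` is defined elsewhere in dora, may carry additional fields (anything left over is forwarded to submitit untouched), and the defaults shown are purely illustrative.

import typing as tp
from dataclasses import dataclass

@dataclass
class SlurmConfig:
    gpus: int = 1                            # total GPUs requested across all nodes
    mem_per_gpu: float = 40.0                # memory budget in GB per GPU
    cpus_per_gpu: int = 10                   # CPUs per GPU, used when cpus_per_task is unset
    cpus_per_task: tp.Optional[int] = None   # explicit per-task CPU count, if any
    one_task_per_node: bool = False          # one task owning every GPU on its node
    max_num_timeout: int = 10                # submitit retries after timeout/preemption
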
def _get_submitit_executor(self, name: str, folder: Path,
                           slurm_config: SlurmConfig) -> submitit.SlurmExecutor:
    # Kill the whole job if any of its tasks fails.
    os.environ['SLURM_KILL_BAD_EXIT'] = '1'
    kwargs = dict(slurm_config.__dict__)
    executor = submitit.SlurmExecutor(
        folder=folder, max_num_timeout=kwargs.pop('max_num_timeout'))
    gpus = slurm_config.gpus
    if gpus > 8:
        # More than one node: only full nodes of 8 GPUs are supported.
        if gpus % 8 != 0:
            raise ValueError("Can only take <= 8 gpus, or multiple of 8 gpus")
        kwargs['nodes'] = gpus // 8
        gpus_per_node = 8
    else:
        gpus_per_node = gpus
        kwargs['nodes'] = 1
    # Memory is requested per node, derived from the per-GPU budget.
    mem = slurm_config.mem_per_gpu * gpus_per_node
    kwargs['mem'] = f"{mem}GB"
    if slurm_config.one_task_per_node:
        # A single task per node owns all of that node's GPUs.
        kwargs['gpus_per_task'] = gpus_per_node
        kwargs['ntasks_per_node'] = 1
        if slurm_config.cpus_per_task is None:
            kwargs['cpus_per_task'] = gpus_per_node * slurm_config.cpus_per_gpu
    else:
        # One task per GPU.
        kwargs['gpus_per_task'] = 1
        kwargs['ntasks_per_node'] = gpus_per_node
        if slurm_config.cpus_per_task is None:
            kwargs['cpus_per_task'] = slurm_config.cpus_per_gpu
    # Drop the convenience fields that submitit does not understand; whatever
    # remains is forwarded as raw Slurm parameters.
    del kwargs['gpus']
    del kwargs['mem_per_gpu']
    del kwargs['cpus_per_gpu']
    del kwargs['one_task_per_node']
    logger.debug("Slurm parameters %r", kwargs)
    executor.update_parameters(
        job_name=name,
        stderr_to_stdout=True,
        **kwargs)
    return executor
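
Since the return value is a plain `submitit.SlurmExecutor`, a job can then be submitted with the standard submitit API. A minimal usage sketch, assuming a shepherd instance `shepherd`, a populated `slurm_config`, and a picklable callable `train_main` (all hypothetical names, not part of shep.py):

executor = shepherd._get_submitit_executor(
    name="my_xp", folder=Path("./outputs/submitit"), slurm_config=slurm_config)
job = executor.submit(train_main)      # serializes the callable and submits it via sbatch
print("submitted", job.job_id)         # Slurm job id assigned at submission
result = job.result()                  # blocks until the job completes or fails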