heyhi/__init__.py
import logging

import submitit

# Note: is_aws() is assumed to be defined elsewhere in the heyhi package.


def _build_slurm_executor(exp_handle, cfg):
    executor = submitit.SlurmExecutor(folder=exp_handle.slurm_path)
    # Either a single machine with fewer than 8 GPUs, or full 8-GPU nodes.
    assert cfg.num_gpus < 8 or cfg.num_gpus % 8 == 0, cfg.num_gpus
    if cfg.num_gpus:
        gpus = min(cfg.num_gpus, 8)
        nodes = max(1, cfg.num_gpus // 8)
        assert (
            gpus * nodes == cfg.num_gpus
        ), "Must use 8 gpus per machine when multiple nodes are used."
    else:
        gpus = 0
        nodes = 1
    if cfg.single_task_per_node:
        ntasks_per_node = 1
    else:
        # Default: one task per GPU on each node.
        ntasks_per_node = gpus
    slurm_params = dict(
        job_name=exp_handle.exp_id,
        partition=cfg.partition,
        time=int(cfg.hours * 60),  # Slurm time limit is given in minutes.
        nodes=nodes,
        num_gpus=gpus,
        ntasks_per_node=ntasks_per_node,
        mem=f"{cfg.mem_per_gpu * max(1, gpus)}GB",  # Memory scales with GPUs per node.
        signal_delay_s=90,  # Send the timeout signal 90s before the job is killed.
        comment=cfg.comment or "",
    )
    if cfg.cpus_per_gpu:
        # Split the node's CPU budget evenly across its tasks.
        slurm_params["cpus_per_task"] = cfg.cpus_per_gpu * gpus // ntasks_per_node
    # GPU constraints are mutually exclusive; if several flags are set, the last one wins.
    if cfg.volta32:
        slurm_params["constraint"] = "volta32gb"
    if cfg.pascal:
        slurm_params["constraint"] = "pascal"
    if cfg.volta:
        slurm_params["constraint"] = "volta"
    if is_aws():
        # AWS cluster: --mem=0 requests all memory on the node, there is a single
        # "compute" partition, and the GPU constraints above do not apply.
        slurm_params["mem"] = 0
        slurm_params["cpus_per_task"] = 1
        slurm_params["partition"] = "compute"
        if "constraint" in slurm_params:
            del slurm_params["constraint"]
    logging.info("Slurm params: %s", slurm_params)
    executor.update_parameters(**slurm_params)
    return executor
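
A minimal usage sketch (not part of heyhi): the SimpleNamespace objects below are illustrative stand-ins for the real experiment handle and config, exposing only the attribute names read by _build_slurm_executor above; the path, partition name, and train_fn callable are hypothetical.

from types import SimpleNamespace

exp_handle = SimpleNamespace(slurm_path="/checkpoint/me/exp_001/slurm", exp_id="exp_001")
cfg = SimpleNamespace(
    num_gpus=16,                 # two full 8-GPU nodes
    single_task_per_node=False,  # -> 8 tasks per node, one per GPU
    partition="learnfair",       # hypothetical partition name
    hours=24,
    mem_per_gpu=50,              # GB of host memory requested per GPU
    comment="",
    cpus_per_gpu=10,
    volta32=True,                # -> constraint=volta32gb
    pascal=False,
    volta=False,
)

executor = _build_slurm_executor(exp_handle, cfg)
job = executor.submit(train_fn)  # train_fn: placeholder for any picklable callable
print(job.job_id)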