def _build_slurm_executor()

in heyhi/__init__.py
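
Builds a submitit.SlurmExecutor for an experiment: it derives the node/GPU/task layout from cfg.num_gpus (at most 8 GPUs per node), fills in Slurm parameters such as partition, time limit, memory, and GPU-architecture constraint, applies AWS-specific overrides, and returns the configured executor.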


import logging

import submitit


def _build_slurm_executor(exp_handle, cfg):
    # exp_handle provides slurm_path/exp_id; is_aws() is defined elsewhere in heyhi.
    executor = submitit.SlurmExecutor(folder=exp_handle.slurm_path)
    # GPU jobs must either fit on one node (fewer than 8 GPUs) or use
    # whole 8-GPU nodes.
    assert cfg.num_gpus < 8 or cfg.num_gpus % 8 == 0, cfg.num_gpus
    if cfg.num_gpus:
        gpus = min(cfg.num_gpus, 8)
        nodes = max(1, cfg.num_gpus // 8)
        assert (
            gpus * nodes == cfg.num_gpus
        ), "Must use 8 gpus per machine when multiple nodes are used."
    else:
        gpus = 0
        nodes = 1

    # Launch either a single task that owns every GPU on the node, or one
    # task per GPU.
    if cfg.single_task_per_node:
        ntasks_per_node = 1
    else:
        ntasks_per_node = gpus

    slurm_params = dict(
        job_name=exp_handle.exp_id,
        partition=cfg.partition,
        time=int(cfg.hours * 60),  # Slurm time limit is in minutes.
        nodes=nodes,
        num_gpus=gpus,  # GPUs requested per node.
        ntasks_per_node=ntasks_per_node,
        mem=f"{cfg.mem_per_gpu * max(1, gpus)}GB",  # Per-node memory, scaled by GPU count.
        signal_delay_s=90,  # Signal the job 90s before timeout so it can checkpoint.
        comment=cfg.comment or "",
    )
    if cfg.cpus_per_gpu:
        # Split the node's CPU budget evenly across its tasks (integer division).
        slurm_params["cpus_per_task"] = cfg.cpus_per_gpu * gpus // ntasks_per_node
    # GPU-architecture constraints; if more than one flag is set, the last
    # assignment wins (volta overrides pascal overrides volta32gb).
    if cfg.volta32:
        slurm_params["constraint"] = "volta32gb"
    if cfg.pascal:
        slurm_params["constraint"] = "pascal"
    if cfg.volta:
        slurm_params["constraint"] = "volta"
    if is_aws():
        # On AWS, force the single "compute" partition, drop any architecture
        # constraint, and request all memory on the node (mem=0).
        slurm_params["mem"] = 0
        slurm_params["cpus_per_task"] = 1
        slurm_params["partition"] = "compute"
        slurm_params.pop("constraint", None)
    logging.info("Slurm params: %s", slurm_params)
    executor.update_parameters(**slurm_params)
    return executor
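
For context, a minimal usage sketch. The cfg and exp_handle below are hypothetical stand-ins (built with types.SimpleNamespace) whose field names simply mirror the attributes the function reads; the partition name, paths, and train_main entry point are placeholders, not part of heyhi:

# Hypothetical driver code; heyhi's real config and experiment-handle types
# are not shown here, so SimpleNamespace stand-ins are used instead.
import types

cfg = types.SimpleNamespace(
    num_gpus=16,                 # -> gpus=8, nodes=2
    single_task_per_node=False,  # -> ntasks_per_node=8
    partition="learnfair",       # hypothetical partition name
    hours=24,                    # -> time=1440 minutes
    mem_per_gpu=62,              # -> mem="496GB" per node
    cpus_per_gpu=10,             # -> cpus_per_task=10
    comment="",
    volta32=True, pascal=False, volta=False,
)
exp_handle = types.SimpleNamespace(
    slurm_path="/checkpoint/me/exp/slurm",  # hypothetical log folder
    exp_id="my_experiment",
)

executor = _build_slurm_executor(exp_handle, cfg)
job = executor.submit(train_main)  # train_main: your own entry point (placeholder)
print(job.job_id)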