def torch_setup()

in ppo_ewma/torch_util.py [0:0]


def torch_setup(device_type=None, gpu_offset=0):
    """
    Setup torch to use the correct device and number of threads.  This should be called before `torch_init_process_group`

    Args:
        device_type: torch device type string (e.g. "cuda" or "cpu"). When None,
            "cuda" is used if `torch.cuda.is_available()`, otherwise "cpu".
        gpu_offset: added to this process's local rank before wrapping around the
            number of visible CUDA devices. Only used when device_type == "cuda".

    Returns the torch device to use. For "cuda" its index is the selected GPU
    (also made current via `torch.cuda.set_device`). For any other type the index is 0.

    Raises:
        RuntimeError: if device_type is "cuda" but torch reports no CUDA devices.
    """
    from mpi4py import MPI
    import torch

    if device_type is None:
        device_type = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE(review): presumably the rank/size of this process among the MPI
    # processes on the same host -- confirm against `_get_local_rank_size`.
    local_rank, local_size = _get_local_rank_size(MPI.COMM_WORLD)
    if device_type == "cuda":
        num_devices = torch.cuda.device_count()
        if num_devices == 0:
            # Without this check the modulo below fails with an opaque
            # ZeroDivisionError (e.g. device_type="cuda" forced on a CPU-only host).
            raise RuntimeError(
                "torch_setup: device_type='cuda' was requested but torch reports "
                "no CUDA devices"
            )
        # Spread local processes round-robin over the visible GPUs.
        device_index = (local_rank + gpu_offset) % num_devices
        torch.cuda.set_device(device_index)
    else:
        device_index = 0

    if "RCALL_NUM_CPU_PER_PROC" in os.environ:
        # Explicit per-process thread count from the environment takes precedence.
        num_threads = int(os.environ["RCALL_NUM_CPU_PER_PROC"])
    else:
        # Split half of the machine's CPU count evenly across local processes,
        # always using at least one thread.
        num_threads = max(round(mp.cpu_count() // 2 / local_size), 1)
    torch.set_num_threads(num_threads)
    return torch.device(type=device_type, index=device_index)