in ppo_ewma/torch_util.py [0:0]
def torch_setup(device_type=None, gpu_offset=0):
    """
    Set up torch to use the correct device and number of threads. This should be called before `torch_init_process_group`

    Args:
        device_type: "cuda" or "cpu". If None, "cuda" is chosen when
            `torch.cuda.is_available()` is true, otherwise "cpu".
        gpu_offset: integer added to this process's local rank before the
            CUDA device index is taken modulo the number of visible devices.

    Returns the torch device to use (a `torch.device` with the chosen type;
    index is the selected GPU for "cuda", and 0 otherwise).

    Raises:
        RuntimeError: if `device_type` is "cuda" but torch reports zero
            visible CUDA devices.
    """
    from mpi4py import MPI
    import torch

    if device_type is None:
        device_type = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE(review): presumably (rank among processes on this host, number of
    # such processes) -- confirm against `_get_local_rank_size`.
    local_rank, local_size = _get_local_rank_size(MPI.COMM_WORLD)

    if device_type == "cuda":
        num_devices = torch.cuda.device_count()
        if num_devices == 0:
            # Without this check the modulo below fails with an opaque
            # ZeroDivisionError when "cuda" is requested explicitly.
            raise RuntimeError(
                "torch_setup: device_type='cuda' was requested but no CUDA devices are visible"
            )
        # Assign local processes round-robin over the visible GPUs.
        device_index = (local_rank + gpu_offset) % num_devices
        torch.cuda.set_device(device_index)
    else:
        device_index = 0

    if "RCALL_NUM_CPU_PER_PROC" in os.environ:
        # An explicit per-process thread budget overrides the heuristic below.
        num_threads = int(os.environ["RCALL_NUM_CPU_PER_PROC"])
    else:
        # Divide half of mp.cpu_count() evenly among the local processes,
        # keeping at least one thread per process.
        num_threads = max(round(mp.cpu_count() // 2 / local_size), 1)
    torch.set_num_threads(num_threads)
    return torch.device(type=device_type, index=device_index)