in graphlearn_torch/python/distributed/dist_options.py [0:0]
def __init__(self,
             num_workers: int = 1,
             worker_devices: Optional[List[torch.device]] = None,
             worker_concurrency: int = 1,
             master_addr: Optional[str] = None,
             master_port: Optional[Union[str, int]] = None,
             num_rpc_threads: Optional[int] = None,
             rpc_timeout: float = 180):
  self.num_workers = num_workers
  # Unknown at construction time; computed later during worker setup.
  self.worker_world_size = None
  self.worker_ranks = None
  if worker_devices is None:
    self.worker_devices = None
  elif isinstance(worker_devices, (list, tuple)):
    assert len(worker_devices) == self.num_workers
    self.worker_devices = list(worker_devices)
  else:
    # A single device is replicated for all workers.
    self.worker_devices = [worker_devices] * self.num_workers
  # Clamp worker concurrency to the range [1, 32].
  self.worker_concurrency = max(worker_concurrency, 1)
  self.worker_concurrency = min(self.worker_concurrency, 32)
  if master_addr is not None:
    self.master_addr = str(master_addr)
  elif os.environ.get('MASTER_ADDR') is not None:
    self.master_addr = os.environ['MASTER_ADDR']
  else:
    raise ValueError(f"'{self.__class__.__name__}': missing master address "
                     "for rpc communication, try to provide it or set it "
                     "with environment variable 'MASTER_ADDR'")
  if master_port is not None:
    self.master_port = int(master_port)
  elif os.environ.get('MASTER_PORT') is not None:
    # Use MASTER_PORT + 1, presumably so the rpc port does not collide with
    # the port already used by the default process group.
    self.master_port = int(os.environ['MASTER_PORT']) + 1
  else:
    raise ValueError(f"'{self.__class__.__name__}': missing master port "
                     "for rpc communication, try to provide it or set it "
                     "with environment variable 'MASTER_PORT'")
  self.num_rpc_threads = num_rpc_threads
  if self.num_rpc_threads is not None:
    # The rpc thread count, if given, must be positive.
    assert self.num_rpc_threads > 0
  self.rpc_timeout = rpc_timeout
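
For context, a minimal usage sketch of this constructor is given below. The enclosing class name is not shown in this excerpt, so `WorkerOptions` is a hypothetical stand-in; the keyword arguments mirror the signature above.

import os

# Explicit endpoint: master_port may be passed as str or int; the
# constructor normalizes it to int.
opts = WorkerOptions(num_workers=2,
                     worker_concurrency=4,
                     master_addr='127.0.0.1',
                     master_port='11234')

# Environment-driven endpoint: with MASTER_ADDR/MASTER_PORT set, the
# address is taken verbatim and the rpc port becomes MASTER_PORT + 1.
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'
opts = WorkerOptions(num_workers=2)
assert opts.master_port == 29501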