def __init__()

in graphlearn_torch/python/distributed/dist_options.py [0:0]


  def __init__(self,
               num_workers: int = 1,
               worker_devices: Optional[List[torch.device]] = None,
               worker_concurrency: int = 1,
               master_addr: Optional[str] = None,
               master_port: Optional[Union[str, int]] = None,
               num_rpc_threads: Optional[int] = None,
               rpc_timeout: float = 180):
    self.num_workers = num_workers

    # Not known yet; these will be calculated later.
    self.worker_world_size = None
    self.worker_ranks = None

    if worker_devices is None:
      self.worker_devices = None
    elif isinstance(worker_devices, (list, tuple)):
      assert len(worker_devices) == self.num_workers
      self.worker_devices = list(worker_devices)
    else:
      self.worker_devices = [worker_devices] * self.num_workers

    # Clamp worker concurrency into the range [1, 32].
    self.worker_concurrency = min(max(worker_concurrency, 1), 32)

    if master_addr is not None:
      self.master_addr = str(master_addr)
    elif os.environ.get('MASTER_ADDR') is not None:
      self.master_addr = os.environ['MASTER_ADDR']
    else:
      raise ValueError(f"'{self.__class__.__name__}': missing master address "
                       "for rpc communication, try to provide it or set it "
                       "with environment variable 'MASTER_ADDR'")

    if master_port is not None:
      self.master_port = int(master_port)
    elif os.environ.get('MASTER_PORT') is not None:
      # Derive a distinct port from the MASTER_PORT environment variable
      # (offset by one).
      self.master_port = int(os.environ['MASTER_PORT']) + 1
    else:
      raise ValueError(f"'{self.__class__.__name__}': missing master port "
                       "for rpc communication, try to provide it or set it "
                       "with environment variable 'MASTER_ADDR'")

    self.num_rpc_threads = num_rpc_threads
    if self.num_rpc_threads is not None:
      assert self.num_rpc_threads > 0
    self.rpc_timeout = rpc_timeout
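
A minimal usage sketch follows. The enclosing class is not shown in this excerpt, so `DistWorkerOptions` below is a placeholder for whichever options class in dist_options.py defines this `__init__`; the argument handling mirrors the code above.

  import torch

  # Explicit address/port; when omitted, the constructor falls back to the
  # MASTER_ADDR / MASTER_PORT environment variables (with the port offset
  # by one) or raises a ValueError if neither is available.
  opts = DistWorkerOptions(          # placeholder class name
    num_workers=2,
    worker_devices=[torch.device('cuda:0'), torch.device('cuda:1')],
    worker_concurrency=4,            # clamped into [1, 32]
    master_addr='127.0.0.1',
    master_port=11234,
    num_rpc_threads=16,
    rpc_timeout=180,
  )
  assert opts.worker_devices == [torch.device('cuda:0'), torch.device('cuda:1')]
  assert opts.worker_concurrency == 4

Note that a single device (rather than a list) is broadcast to all workers, while a list or tuple must match `num_workers` exactly.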