def __init__()

in graphlearn_torch/python/distributed/dist_options.py [0:0]


  def __init__(self,
               server_rank: Optional[Union[int, List[int]]] = None,
               num_workers: int = 1,
               worker_devices: Optional[List[torch.device]] = None,
               worker_concurrency: int = 4,
               master_addr: Optional[str] = None,
               master_port: Optional[Union[str, int]] = None,
               num_rpc_threads: Optional[int] = None,
               rpc_timeout: float = 180,
               buffer_size: Optional[Union[int, str]] = None,
               prefetch_size: int = 4,
               worker_key: Optional[str] = None,
               glt_graph = None,
               workload_type: Optional[Literal['train', 'validate', 'test']] = None,
               use_all2all: bool = False):
    # On the GraphScope side, glt_graph carries the connection parameters
    # (master address and per-workload loader ports), so read them from it.
    if glt_graph:
      if not workload_type:
        raise ValueError(f"'{self.__class__.__name__}': 'workload_type' is "
                         f"required when 'glt_graph' is provided")
      master_addr = glt_graph.master_addr
      if workload_type == 'train':
        master_port = glt_graph.train_loader_master_port
      elif workload_type == 'validate':
        master_port = glt_graph.val_loader_master_port
      elif workload_type == 'test':
        master_port = glt_graph.test_loader_master_port
      else:
        raise ValueError(f"'{self.__class__.__name__}': unknown workload_type "
                         f"'{workload_type}', expected 'train', 'validate' or 'test'")
      # The chosen port also serves as the key identifying this group of workers.
      worker_key = str(master_port)

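    # Common worker and RPC options are handled by the base options class.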
    super().__init__(num_workers, worker_devices, worker_concurrency,
                     master_addr, master_port, num_rpc_threads, rpc_timeout)
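    # Honor an explicitly requested server rank (or list of ranks); otherwise
    # let assign_server_by_order() pick one automatically.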
    if server_rank is not None:
      self.server_rank = server_rank
    else:
      self.server_rank = assign_server_by_order()
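    # The buffer holds one slot per concurrent sampling task across all
    # workers; the default byte budget scales at 64MB per worker.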
    self.buffer_capacity = self.num_workers * self.worker_concurrency
    if buffer_size is None:
      self.buffer_size = f'{self.num_workers * 64}MB'
    else:
      self.buffer_size = buffer_size

    self.prefetch_size = prefetch_size
    if self.prefetch_size > self.buffer_capacity:
      raise ValueError(f"'{self.__class__.__name__}': the prefetch size "
                       f"{self.prefetch_size} exceeds the buffer capacity "
                       f"{self.buffer_capacity}")
    self.worker_key = worker_key
    self.use_all2all = use_all2all
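
A minimal usage sketch, assuming this __init__ belongs to graphlearn_torch's
RemoteDistSamplingWorkerOptions (the class name is not shown in this excerpt);
the address, port, devices, and sizes below are illustrative placeholders, not
recommended defaults:

import torch
import graphlearn_torch as glt

# All concrete values here are illustrative assumptions.
worker_options = glt.distributed.RemoteDistSamplingWorkerOptions(
    server_rank=0,                     # sample from server 0 (omit to auto-assign)
    num_workers=2,                     # two server-side sampling workers
    worker_devices=[torch.device('cuda', i) for i in range(2)],
    worker_concurrency=4,              # up to 4 in-flight tasks per worker
    master_addr='localhost',           # rendezvous address for worker RPC setup
    master_port=11234,
    buffer_size='256MB',               # overrides the 64MB-per-worker default
    prefetch_size=4,                   # must not exceed num_workers * worker_concurrency
)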