graphlearn_torch/python/distributed/dist_options.py
def __init__(self,
             server_rank: Optional[Union[int, List[int]]] = None,
             num_workers: int = 1,
             worker_devices: Optional[List[torch.device]] = None,
             worker_concurrency: int = 4,
             master_addr: Optional[str] = None,
             master_port: Optional[Union[str, int]] = None,
             num_rpc_threads: Optional[int] = None,
             rpc_timeout: float = 180,
             buffer_size: Optional[Union[int, str]] = None,
             prefetch_size: int = 4,
             worker_key: Optional[str] = None,
             glt_graph=None,
             workload_type: Optional[Literal['train', 'validate', 'test']] = None,
             use_all2all: bool = False):
    # glt_graph is passed in on the GraphScope side and supplies the
    # connection parameters (master address and per-workload port).
    if glt_graph:
        if not workload_type:
            raise ValueError(f"'{self.__class__.__name__}': missing workload_type")
        master_addr = glt_graph.master_addr
        if workload_type == 'train':
            master_port = glt_graph.train_loader_master_port
        elif workload_type == 'validate':
            master_port = glt_graph.val_loader_master_port
        elif workload_type == 'test':
            master_port = glt_graph.test_loader_master_port
        # Derive the worker group key from the resolved port, which is
        # unique per workload type.
        worker_key = str(master_port)
    super().__init__(num_workers, worker_devices, worker_concurrency,
                     master_addr, master_port, num_rpc_threads, rpc_timeout)
    if server_rank is not None:
        self.server_rank = server_rank
    else:
        self.server_rank = assign_server_by_order()
    # Capacity in messages: one slot per concurrent sampling task
    # across all workers.
    self.buffer_capacity = self.num_workers * self.worker_concurrency
    if buffer_size is None:
        # Default: scale the buffer size with the number of workers.
        self.buffer_size = f'{self.num_workers * 64}MB'
    else:
        self.buffer_size = buffer_size
    self.prefetch_size = prefetch_size
    if self.prefetch_size > self.buffer_capacity:
        raise ValueError(f"'{self.__class__.__name__}': the prefetch count "
                         f"{self.prefetch_size} exceeds the buffer capacity "
                         f"{self.buffer_capacity}")
    self.worker_key = worker_key
    self.use_all2all = use_all2all
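
For reference, a minimal construction sketch follows. It assumes this `__init__` belongs to GLT's remote sampling worker options class, named `RemoteDistSamplingWorkerOptions` here and imported from `graphlearn_torch.distributed`; the addresses, ports, and ranks are placeholder values, not part of the source above.

# A minimal sketch, assuming the class above is
# graphlearn_torch.distributed.RemoteDistSamplingWorkerOptions;
# every concrete value below is a placeholder.
from graphlearn_torch.distributed import RemoteDistSamplingWorkerOptions

options = RemoteDistSamplingWorkerOptions(
    server_rank=[0, 1],        # sample on servers 0 and 1 (else assigned by order)
    num_workers=2,             # sampling workers spawned per server
    worker_concurrency=4,      # concurrent sampling tasks per worker
    master_addr='localhost',   # placeholder; resolved from glt_graph on GraphScope
    master_port=11234,         # placeholder port for this workload's loader
    buffer_size='128MB',       # overrides the num_workers * 64MB default
    prefetch_size=4,           # must not exceed num_workers * worker_concurrency
)

With these values the buffer capacity is 2 * 4 = 8 messages, so any `prefetch_size` above 8 would trigger the `ValueError` raised at the end of `__init__`.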