def __init__()

in graphlearn_torch/python/distributed/dist_neighbor_loader.py [0:0]


  def __init__(self,
               data: Optional[DistDataset],
               num_neighbors: NumNeighbors,
               input_nodes: InputNodes,
               batch_size: int = 1,
               shuffle: bool = False,
               drop_last: bool = False,
               with_edge: bool = False,
               with_weight: bool = False,
               edge_dir: Literal['in', 'out'] = 'out',
               collect_features: bool = False,
               to_device: Optional[torch.device] = None,
               random_seed: Optional[int] = None,
               worker_options: Optional[AllDistSamplingWorkerOptions] = None):
    r"""Build a node-wise neighbor-sampling loader.

    ``input_nodes`` is normalized into a sampler input object, the remaining
    sampling options are packed into a ``SamplingConfig`` of type
    ``SamplingType.NODE`` (with ``with_neg=False``), and both are handed to
    the base-class constructor together with ``data``, ``to_device`` and
    ``worker_options``.

    Args:
      data: Dataset forwarded to the base class. May be ``None``
        (presumably when sampling runs on remote servers -- confirm
        against the base class).
      num_neighbors: Forwarded positionally to ``SamplingConfig``
        (presumably the per-hop fan-out -- TODO confirm).
      input_nodes: Either the seeds alone, or an ``(input_type, seeds)``
        tuple. When ``worker_options`` is a
        ``RemoteDistSamplingWorkerOptions``, the seeds must be a ``Split``,
        a ``list`` of node paths, or a single node-path ``str``. Otherwise
        they are wrapped in ``NodeSamplerInput``.
      batch_size, shuffle, drop_last, with_edge, with_weight, edge_dir,
      collect_features: Forwarded unchanged to ``SamplingConfig``.
      to_device: Forwarded to the base class.
      random_seed: Forwarded to ``SamplingConfig`` as ``seed``.
      worker_options: Forwarded to the base class. Its type also decides
        how ``input_nodes`` is wrapped (remote vs. local sampler input).

    Raises:
      ValueError: If remote worker options are given and the seeds are not
        a ``Split``, ``list`` or ``str``.
    """
    if isinstance(input_nodes, tuple):
      input_type, input_seeds = input_nodes
    else:
      input_type, input_seeds = None, input_nodes

    if isinstance(worker_options, RemoteDistSamplingWorkerOptions):
      if isinstance(input_seeds, Split):
        input_data = RemoteNodeSplitSamplerInput(
          split=input_seeds, input_type=input_type)
        if isinstance(worker_options.server_rank, list):
          # One entry per rank in `server_rank`. `[x] * n` repeats the same
          # object, so every entry aliases a single sampler input.
          input_data = [input_data] * len(worker_options.server_rank)
      elif isinstance(input_seeds, list):
        # One path-based sampler input per element of the list.
        input_data = [
          RemoteNodePathSamplerInput(node_path=path, input_type=input_type)
          for path in input_seeds
        ]
      elif isinstance(input_seeds, str):
        input_data = RemoteNodePathSamplerInput(
          node_path=input_seeds, input_type=input_type)
      else:
        raise ValueError(
          f"Invalid input seeds: expected a Split, a list of node paths or "
          f"a node path string for remote sampling, got "
          f"{type(input_seeds).__name__}")
    else:
      input_data = NodeSamplerInput(node=input_seeds, input_type=input_type)

    sampling_config = SamplingConfig(
      SamplingType.NODE, num_neighbors, batch_size, shuffle,
      drop_last, with_edge, collect_features, with_neg=False,
      with_weight=with_weight, edge_dir=edge_dir, seed=random_seed
    )

    super().__init__(
      data, input_data, sampling_config, to_device, worker_options
    )