in graphlearn_torch/python/distributed/dist_neighbor_loader.py [0:0]
def __init__(self,
             data: Optional[DistDataset],
             num_neighbors: NumNeighbors,
             input_nodes: InputNodes,
             batch_size: int = 1,
             shuffle: bool = False,
             drop_last: bool = False,
             with_edge: bool = False,
             with_weight: bool = False,
             edge_dir: Literal['in', 'out'] = 'out',
             collect_features: bool = False,
             to_device: Optional[torch.device] = None,
             random_seed: Optional[int] = None,
             worker_options: Optional[AllDistSamplingWorkerOptions] = None):
  """Build node-seeded sampler input plus a NODE sampling config, then
  delegate to the base class constructor.

  Args:
    data: Dataset forwarded unchanged to the base class constructor.
    num_neighbors: Neighbor fan-out forwarded to ``SamplingConfig``.
    input_nodes: Seed nodes, either bare or as an ``(input_type, seeds)``
      tuple. When ``worker_options`` is a
      ``RemoteDistSamplingWorkerOptions``, the seeds must be a ``Split``,
      a list of node paths, or a single node-path string.
    batch_size, shuffle, drop_last, with_edge, collect_features: Passed
      positionally to ``SamplingConfig``.
    with_weight, edge_dir: Passed by keyword to ``SamplingConfig``.
    to_device: Forwarded to the base class constructor.
    random_seed: Passed to ``SamplingConfig`` as ``seed``.
    worker_options: Selects between the local (``NodeSamplerInput``) and
      remote (``RemoteNode*SamplerInput``) input representations; also
      forwarded to the base class constructor.

  Raises:
    ValueError: if remote worker options are given and the seeds are not
      a ``Split``, a ``list``, or a ``str``. The same exception type also
      propagates if a tuple ``input_nodes`` does not unpack into exactly
      two items.
  """
  # Seeds may come with an explicit type as a 2-tuple; otherwise untyped.
  seed_type = None
  seeds = input_nodes
  if isinstance(input_nodes, tuple):
    seed_type, seeds = input_nodes

  remote = isinstance(worker_options, RemoteDistSamplingWorkerOptions)
  if not remote:
    sampler_input = NodeSamplerInput(node=seeds, input_type=seed_type)
  elif isinstance(seeds, Split):
    sampler_input = RemoteNodeSplitSamplerInput(split=seeds,
                                                input_type=seed_type)
    if isinstance(worker_options.server_rank, list):
      # Replicate the same split input once per entry of ``server_rank``.
      sampler_input = [sampler_input] * len(worker_options.server_rank)
  elif isinstance(seeds, list):
    sampler_input = [
      RemoteNodePathSamplerInput(node_path=path, input_type=seed_type)
      for path in seeds
    ]
  elif isinstance(seeds, str):
    sampler_input = RemoteNodePathSamplerInput(node_path=seeds,
                                               input_type=seed_type)
  else:
    raise ValueError("Invalid input seeds")

  config = SamplingConfig(
    SamplingType.NODE, num_neighbors, batch_size, shuffle,
    drop_last, with_edge, collect_features, with_neg=False,
    with_weight=with_weight, edge_dir=edge_dir, seed=random_seed
  )
  super().__init__(data, sampler_input, config, to_device, worker_options)