in graphlearn_torch/python/distributed/dist_subgraph_loader.py [0:0]
def __init__(self,
             data: Optional[DistDataset],
             input_nodes: InputNodes,
             num_neighbors: Optional[NumNeighbors] = None,
             batch_size: int = 1,
             shuffle: bool = False,
             drop_last: bool = False,
             with_edge: bool = False,
             with_weight: bool = False,
             edge_dir: Literal['in', 'out'] = 'out',
             collect_features: bool = False,
             to_device: Optional[torch.device] = None,
             random_seed: Optional[int] = None,
             worker_options: Optional[AllDistSamplingWorkerOptions] = None):
  # Normalize the seed input: a (node_type, seeds) tuple for heterogeneous
  # graphs, or a plain tensor/list of seed ids for homogeneous graphs.
  if isinstance(input_nodes, tuple):
    input_type, input_seeds = input_nodes
  else:
    input_type, input_seeds = None, input_nodes
  input_data = NodeSamplerInput(node=input_seeds, input_type=input_type)
  # TODO: currently only supports out-sampling.
  # Build a subgraph sampling configuration; negative sampling is disabled.
  sampling_config = SamplingConfig(
    SamplingType.SUBGRAPH, num_neighbors, batch_size, shuffle,
    drop_last, with_edge, collect_features, with_neg=False,
    with_weight=with_weight, edge_dir=edge_dir, seed=random_seed
  )
  # Initialize the distributed loader base class with the prepared sampler
  # input, sampling config, target device and worker options.
  super().__init__(
    data, input_data, sampling_config, to_device, worker_options
  )
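
A minimal usage sketch, assuming the enclosing class is `DistSubGraphLoader` (inferred from the file name) and that a `DistDataset` and `AllDistSamplingWorkerOptions` have been built elsewhere; only the constructor arguments are taken from the signature above:

  # Hypothetical setup: `dist_dataset` and `worker_options` come from the
  # surrounding distributed training code and are not defined in this file.
  loader = DistSubGraphLoader(
    data=dist_dataset,
    input_nodes=('paper', torch.arange(1000)),  # (node type, seed ids); a plain tensor also works
    batch_size=32,
    shuffle=True,
    with_edge=True,
    collect_features=True,
    worker_options=worker_options,
  )
  for batch in loader:
    ...  # each batch holds a sampled subgraph around its seed nodes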