in graphlearn_torch/python/distributed/dist_link_neighbor_loader.py [0:0]
def __init__(self,
             data: Optional[DistDataset],
             num_neighbors: NumNeighbors,
             batch_size: int = 1,
             edge_label_index: InputEdges = None,
             edge_label: Optional[torch.Tensor] = None,
             neg_sampling: Optional[NegativeSampling] = None,
             shuffle: bool = False,
             drop_last: bool = False,
             with_edge: bool = False,
             with_weight: bool = False,
             edge_dir: Literal['in', 'out'] = 'out',
             collect_features: bool = False,
             to_device: Optional[torch.device] = None,
             random_seed: Optional[int] = None,
             worker_options: Optional[AllDistSamplingWorkerOptions] = None):
    r"""Build a distributed link-level neighbor loader.

    Resolves the seed edges (and their edge type), normalizes the labels
    for the requested negative-sampling mode, then hands an
    ``EdgeSamplerInput`` and a ``SamplingType.LINK`` ``SamplingConfig`` to
    the base loader.

    Args:
        data: Dataset passed to ``get_edge_label_index`` and to the base
            loader.
        num_neighbors: Neighbor-sampling configuration, forwarded to
            ``SamplingConfig``.
        batch_size: Mini-batch size, forwarded to ``SamplingConfig``.
        edge_label_index: Seed edges. Resolved by
            ``get_edge_label_index``, which also yields the edge type
            (``None`` for homogeneous graphs).
        edge_label: Optional labels for the seed edges. With binary
            negative sampling, a non-empty label tensor whose minimum is
            ``0`` is shifted by one so that ``0`` can denote "negative".
        neg_sampling: Negative-sampling configuration, normalized with
            ``NegativeSampling.cast`` and stored as ``self.neg_sampling``.
        shuffle, drop_last, with_edge, with_weight, edge_dir,
        collect_features: Forwarded unchanged to ``SamplingConfig``.
        to_device: Forwarded to the base loader.
        random_seed: Forwarded to ``SamplingConfig`` as ``seed``.
        worker_options: Forwarded to the base loader.

    Raises:
        ValueError: If ``edge_label`` is given together with
            triplet-based negative sampling.
    """
    # Get edge type (or `None` for homogeneous graphs):
    input_type, edge_label_index = get_edge_label_index(
        data, edge_label_index)
    with_neg = neg_sampling is not None
    self.neg_sampling = NegativeSampling.cast(neg_sampling)

    if self.neg_sampling is not None and edge_label is not None:
        # Validate before transforming: triplet mode encodes positives and
        # negatives via batch indices, so explicit labels are rejected.
        if self.neg_sampling.is_triplet():
            raise ValueError("'edge_label' needs to be undefined for "
                             "'triplet'-based negative sampling. Please use "
                             "`src_index`, `dst_pos_index` and "
                             "`dst_neg_index` of the returned mini-batch "
                             "instead to differentiate between positive and "
                             "negative samples.")
        # Increment labels such that `zero` now denotes "negative".
        # The `numel()` guard avoids `min()` raising on an empty tensor.
        if (self.neg_sampling.is_binary() and edge_label.numel() > 0
                and edge_label.min() == 0):
            edge_label = edge_label + 1

    # Clone the endpoints so the sampler input does not alias the
    # caller-provided index tensor.
    input_data = EdgeSamplerInput(
        row=edge_label_index[0].clone(),
        col=edge_label_index[1].clone(),
        label=edge_label,
        input_type=input_type,
        neg_sampling=self.neg_sampling,
    )
    sampling_config = SamplingConfig(
        SamplingType.LINK, num_neighbors, batch_size, shuffle,
        drop_last, with_edge, collect_features, with_neg,
        with_weight=with_weight, edge_dir=edge_dir, seed=random_seed
    )
    super().__init__(
        data, input_data, sampling_config, to_device, worker_options
    )