def __init__()

in graphlearn_torch/python/distributed/dist_link_neighbor_loader.py [0:0]


  def __init__(self,
               data: Optional[DistDataset],
               num_neighbors: NumNeighbors,
               batch_size: int = 1,
               edge_label_index: InputEdges = None,
               edge_label: Optional[torch.Tensor] = None,
               neg_sampling: Optional[NegativeSampling] = None,
               shuffle: bool = False,
               drop_last: bool = False,
               with_edge: bool = False,
               with_weight: bool = False,
               edge_dir: Literal['in', 'out'] = 'out',
               collect_features: bool = False,
               to_device: Optional[torch.device] = None,
               random_seed: Optional[int] = None,
               worker_options: Optional[AllDistSamplingWorkerOptions] = None):
    r"""Build a distributed link-level neighbor loader.

    Seed edges are resolved from ``edge_label_index``. Optional edge labels
    and negative-sampling settings are validated. Everything is then packed
    into an ``EdgeSamplerInput`` and a ``SamplingConfig`` of type
    ``SamplingType.LINK``, and both are handed to the base-class constructor.

    Args:
      data: The distributed dataset. It is used to resolve
        ``edge_label_index`` and is forwarded to the base class.
      num_neighbors: Neighbor fan-out, forwarded to ``SamplingConfig``.
      batch_size: Number of seed edges per mini-batch, forwarded to
        ``SamplingConfig``.
      edge_label_index: Seed edges. ``get_edge_label_index`` resolves this
        into an edge type (``None`` for homogeneous graphs) and an index
        tensor. Rows 0 and 1 of that index become the sampler input's
        ``row`` / ``col``.
      edge_label: Optional labels of the seed edges. Must be ``None`` for
        triplet-based negative sampling. For binary negative sampling, if
        the smallest label is 0, all labels are shifted by +1 so that 0
        can denote negative samples.
      neg_sampling: Negative sampling settings, normalized with
        ``NegativeSampling.cast``.
      shuffle, drop_last, with_edge, with_weight, edge_dir,
      collect_features: Forwarded to ``SamplingConfig``.
      to_device: Forwarded to the base-class constructor.
      random_seed: Forwarded to ``SamplingConfig`` as ``seed``.
      worker_options: Forwarded to the base-class constructor.

    Raises:
      ValueError: If ``edge_label`` is given together with triplet-based
        negative sampling.
    """
    # Get edge type (or `None` for homogeneous graphs):
    input_type, edge_label_index = get_edge_label_index(
        data, edge_label_index)
    with_neg = neg_sampling is not None
    self.neg_sampling = NegativeSampling.cast(neg_sampling)

    # Validate up front: triplet mode reports positives/negatives through
    # index attributes of the mini-batch, so explicit labels are rejected.
    if (self.neg_sampling is not None and self.neg_sampling.is_triplet()
        and edge_label is not None):
      raise ValueError("'edge_label' needs to be undefined for "
                       "'triplet'-based negative sampling. Please use "
                       "`src_index`, `dst_pos_index` and "
                       "`dst_neg_index` of the returned mini-batch "
                       "instead to differentiate between positive and "
                       "negative samples.")

    # `numel() > 0` guards `min()`, which raises on an empty tensor.
    if (self.neg_sampling is not None and self.neg_sampling.is_binary()
        and edge_label is not None and edge_label.numel() > 0
        and edge_label.min() == 0):
      # Increment labels such that `zero` now denotes "negative".
      edge_label = edge_label + 1

    # Clone so the sampler input does not alias the caller's index tensor.
    input_data = EdgeSamplerInput(
      row=edge_label_index[0].clone(),
      col=edge_label_index[1].clone(),
      label=edge_label,
      input_type=input_type,
      neg_sampling=self.neg_sampling,
    )

    sampling_config = SamplingConfig(
      SamplingType.LINK, num_neighbors, batch_size, shuffle,
      drop_last, with_edge, collect_features, with_neg,
      with_weight=with_weight, edge_dir=edge_dir, seed=random_seed
    )

    super().__init__(
      data, input_data, sampling_config, to_device, worker_options
    )