def __init__()

in graphlearn_torch/python/distributed/dist_neighbor_sampler.py [0:0]


  def __init__(self,
               data: DistDataset,
               num_neighbors: Optional[NumNeighbors] = None,
               with_edge: bool = False,
               with_neg: bool = False,
               with_weight: bool = False,
               edge_dir: Literal['in', 'out'] = 'out',
               collect_features: bool = False,
               channel: Optional[ChannelBase] = None,
               use_all2all: bool = False,
               concurrency: int = 1,
               device: Optional[torch.device] = None,
               seed: Optional[int] = None):
    r"""Set up a distributed neighbor sampler over a partitioned dataset.

    Args:
      data (DistDataset): The local partition of the distributed dataset.
        Any other type raises ``ValueError``.
      num_neighbors (NumNeighbors, optional): Fan-out configuration forwarded
        to the local ``NeighborSampler``. For heterogeneous graphs it is
        replaced by the sampler's own ``num_neighbors`` after construction.
      with_edge (bool): Forwarded to ``NeighborSampler``. When
        ``collect_features`` is also set, edge features (if present) are
        wrapped in a ``DistFeature``.
      with_neg (bool): Forwarded to ``NeighborSampler``.
      with_weight (bool): Forwarded to ``NeighborSampler``.
      edge_dir ('in' or 'out'): Forwarded to ``NeighborSampler``.
      collect_features (bool): If ``True``, node features (and edge features
        when ``with_edge`` is set) are wrapped in ``DistFeature`` objects
        built with ``local_only=False`` and this sampler's RPC router.
      channel (ChannelBase, optional): Stored as ``self.channel``; not used
        inside this constructor.
      use_all2all (bool): Stored as ``self.use_all2all``; not used inside
        this constructor.
      concurrency (int): Capacity of ``self.inducer_pool`` and the value
        passed to the base-class constructor.
      device (torch.device, optional): Resolved via ``get_available_device``.
      seed (int, optional): Forwarded to ``NeighborSampler``.

    Raises:
      ValueError: If ``data`` is not a ``DistDataset``, or if the node labels
        of a heterogeneous dataset are neither ``None`` nor a ``dict``.
    """
    self.data = data
    self.use_all2all = use_all2all
    self.num_neighbors = num_neighbors
    self.max_input_size = 0
    self.with_edge = with_edge
    self.with_neg = with_neg
    self.with_weight = with_weight
    self.edge_dir = edge_dir
    self.collect_features = collect_features
    self.channel = channel
    self.concurrency = concurrency
    self.device = get_available_device(device)
    self.seed = seed

    if not isinstance(data, DistDataset):
      raise ValueError(f"'{self.__class__.__name__}': found invalid input "
                       f"data type '{type(data)}'")

    # NOTE(review): by its name this exchanges the partition -> worker
    # mapping with the other data-partition processes; presumably every
    # participant must reach this call. Confirm before moving it.
    partition2workers = rpc_sync_data_partitions(
      num_data_partitions=self.data.num_partitions,
      current_partition_idx=self.data.partition_idx
    )
    self.rpc_router = RpcDataPartitionRouter(partition2workers)

    self.dist_graph = DistGraph(
      data.num_partitions, data.partition_idx,
      data.graph, data.node_pb, data.edge_pb
    )

    self.dist_node_feature = None
    self.dist_edge_feature = None
    if self.collect_features:
      if data.node_features is not None:
        self.dist_node_feature = DistFeature(
          data.num_partitions, data.partition_idx,
          data.node_features, data.node_feat_pb,
          local_only=False, rpc_router=self.rpc_router, device=self.device
        )
      # Edge features are only wrapped when edges are sampled as well.
      if self.with_edge and data.edge_features is not None:
        self.dist_edge_feature = DistFeature(
          data.num_partitions, data.partition_idx,
          data.edge_features, data.edge_feat_pb,
          local_only=False, rpc_router=self.rpc_router, device=self.device
        )

    # Node labels start as the dataset's raw labels. In the v6d case they are
    # not plain tensors and should be wrapped in a DistFeature; labels reuse
    # the node-feature partition book (`node_feat_pb`).
    self.dist_node_labels = self.data.node_labels
    if self.dist_graph.data_cls == 'homo':
      if (self.dist_node_labels is not None
          and not isinstance(self.dist_node_labels, torch.Tensor)):
        self.dist_node_labels = DistFeature(
          self.data.num_partitions, self.data.partition_idx,
          self.dist_node_labels, self.data.node_feat_pb,
          local_only=False, rpc_router=self.rpc_router, device=self.device
        )
    else:
      # Heterogeneous labels must be None or a per-type dict. An explicit
      # raise is used instead of `assert`, which is stripped under `python -O`.
      if (self.dist_node_labels is not None
          and not isinstance(self.dist_node_labels, dict)):
        raise ValueError(f"'{self.__class__.__name__}': expected node labels "
                         f"of heterogeneous data to be a dict or None, got "
                         f"'{type(self.dist_node_labels)}'")
      # NOTE(review): `all()` is True for an empty dict, so an empty label
      # dict is also wrapped in a DistFeature (behavior kept as-is).
      if (self.dist_node_labels is not None
          and all(isinstance(value, Feature)
                  for value in self.dist_node_labels.values())):
        self.dist_node_labels = DistFeature(
          self.data.num_partitions, self.data.partition_idx,
          self.data.node_labels, self.data.node_feat_pb,
          local_only=False, rpc_router=self.rpc_router, device=self.device
        )

    self.sampler = NeighborSampler(
      self.dist_graph.local_graph, self.num_neighbors,
      self.device, self.with_edge, self.with_neg, self.with_weight,
      self.edge_dir, seed=self.seed
    )
    self.inducer_pool = queue.Queue(maxsize=self.concurrency)

    # rpc register
    # NOTE(review): registration order is kept exactly as before; callee ids
    # are presumably expected to line up across workers. Confirm before
    # reordering these two registrations.
    rpc_sample_callee = RpcSamplingCallee(self.sampler, self.device)
    self.rpc_sample_callee_id = rpc_register(rpc_sample_callee)
    rpc_subgraph_callee = RpcSubGraphCallee(self.sampler, self.device)
    self.rpc_subgraph_callee_id = rpc_register(rpc_subgraph_callee)

    if self.dist_graph.data_cls == 'hetero':
      # Adopt the sampler's per-edge-type view of the sampling configuration.
      self.num_neighbors = self.sampler.num_neighbors
      self.num_hops = self.sampler.num_hops
      self.edge_types = self.sampler.edge_types

    super().__init__(self.concurrency)
    # `self._loop` is presumably the event loop created by the base class;
    # `ensure_device` is scheduled onto it thread-safely.
    self._loop.call_soon_threadsafe(ensure_device, self.device)