def __init__()

in graphlearn_torch/python/distributed/dist_feature.py [0:0]


  def __init__(self,
               num_partitions: int,
               partition_idx: int,
               local_feature: Union[Feature,
                                    Dict[Union[NodeType, EdgeType], Feature]],
               feature_pb: Union[PartitionBook,
                                 HeteroNodePartitionDict,
                                 HeteroEdgePartitionDict],
               local_only: bool = False,
               rpc_router: Optional[RpcDataPartitionRouter] = None,
               device: Optional[torch.device] = None):
    """Set up the feature store for one partition of a distributed graph.

    Args:
      num_partitions: Total number of partitions.
      partition_idx: Index of the partition held by this instance.
      local_feature: Either a single ``Feature`` (homogeneous data, sets
        ``data_cls`` to ``'homo'``) or a dict keyed by node/edge type
        (heterogeneous data, sets ``data_cls`` to ``'hetero'``). Each
        ``Feature`` found gets ``lazy_init_with_ipc_handle()`` called on it;
        dict values that are not ``Feature`` instances are skipped.
      feature_pb: Partition book(s) matching ``local_feature``. For homo
        data: a ``PartitionBook``, or a ``torch.Tensor`` that is wrapped in
        ``GLTPartitionBook``. For hetero data: a dict whose values are kept
        if they are ``PartitionBook`` instances and wrapped in
        ``GLTPartitionBook`` otherwise. The caller's dict is not modified;
        a new dict is stored on ``self.feature_pb``.
      local_only: If ``True``, no RPC callee is registered and
        ``rpc_callee_id`` is ``None``.
      rpc_router: Router stored on ``self.rpc_router``; required when
        ``local_only`` is ``False``.
      device: Requested device, resolved via ``get_available_device``.

    Raises:
      ValueError: If ``local_feature`` or ``feature_pb`` has an unsupported
        type, if a dict of partition books is paired with homo features (or
        a single partition book with hetero features), or if ``rpc_router``
        is ``None`` while ``local_only`` is ``False``.
    """
    self.num_partitions = num_partitions
    self.partition_idx = partition_idx

    self.device = get_available_device(device)
    ensure_device(self.device)

    self.local_feature = local_feature
    if isinstance(self.local_feature, dict):
      self.data_cls = 'hetero'
      for feat in self.local_feature.values():
        # Non-``Feature`` values are tolerated and left as-is.
        if isinstance(feat, Feature):
          feat.lazy_init_with_ipc_handle()
    elif isinstance(self.local_feature, Feature):
      self.data_cls = 'homo'
      self.local_feature.lazy_init_with_ipc_handle()
    else:
      raise ValueError(f"'{self.__class__.__name__}': found invalid input "
                       f"feature type '{type(self.local_feature)}'")

    if isinstance(feature_pb, dict):
      # Explicit check instead of ``assert`` so it also runs under ``-O``.
      if self.data_cls != 'hetero':
        raise ValueError(f"'{self.__class__.__name__}': a dict of partition "
                         f"books requires hetero (dict) local features")
      # Build a new dict rather than rewriting the caller's mapping in place.
      self.feature_pb = {
        key: pb if isinstance(pb, PartitionBook) else GLTPartitionBook(pb)
        for key, pb in feature_pb.items()
      }
    elif isinstance(feature_pb, (PartitionBook, torch.Tensor)):
      if self.data_cls != 'homo':
        raise ValueError(f"'{self.__class__.__name__}': a single partition "
                         f"book requires homo (non-dict) local features")
      # ``PartitionBook`` is checked first so such objects are kept unwrapped.
      if isinstance(feature_pb, PartitionBook):
        self.feature_pb = feature_pb
      else:
        self.feature_pb = GLTPartitionBook(feature_pb)
    else:
      raise ValueError(f"'{self.__class__.__name__}': found invalid input "
                       f"partition book type '{type(feature_pb)}'")

    self.rpc_router = rpc_router
    if local_only:
      self.rpc_callee_id = None
    else:
      if self.rpc_router is None:
        raise ValueError(f"'{self.__class__.__name__}': a rpc router must be "
                         f"provided when `local_only` set to `False`")
      # Register a callee wrapping this instance; the returned id is kept so
      # the registration can be referenced later.
      rpc_callee = RpcFeatureLookupCallee(self)
      self.rpc_callee_id = rpc_register(rpc_callee)