def __init__()

in graphlearn_torch/python/distributed/dist_graph.py [0:0]


  def __init__(self,
               num_partitions: int,
               partition_idx: int,
               local_graph: Union[Graph, Dict[EdgeType, Graph]],
               node_pb: Union[PartitionBook, HeteroNodePartitionDict],
               edge_pb: Union[PartitionBook, HeteroEdgePartitionDict]=None):
    r"""Hold the local graph partition(s) together with their partition books.

    Args:
      num_partitions: Total number of partitions.
      partition_idx: Index of the partition held by this instance.
      local_graph: Either a single ``Graph`` (sets ``data_cls`` to
        ``'homo'``) or a dict of ``Graph`` objects keyed by edge type (sets
        ``data_cls`` to ``'hetero'``). ``lazy_init()`` is called on every
        graph.
      node_pb: Node partition book, or ``None``. For homo data: a
        ``PartitionBook`` (kept as-is) or a ``torch.Tensor`` (wrapped in
        ``GLTPartitionBook``). For hetero data: a dict whose values that are
        not already ``PartitionBook`` instances are wrapped in
        ``GLTPartitionBook``.
      edge_pb: Edge partition book, accepted in the same forms as
        ``node_pb``. Defaults to ``None``.

    Raises:
      ValueError: If ``local_graph`` or a partition book has an unsupported
        type, or if a partition book's form (dict vs. single book) does not
        match ``data_cls``.
    """
    self.num_partitions = num_partitions
    self.partition_idx = partition_idx
    self.local_graph = local_graph
    if isinstance(self.local_graph, dict):
      self.data_cls = 'hetero'
      for graph in self.local_graph.values():
        graph.lazy_init()
    elif isinstance(self.local_graph, Graph):
      self.data_cls = 'homo'
      self.local_graph.lazy_init()
    else:
      raise ValueError(f"'{self.__class__.__name__}': found invalid input "
                       f"graph type '{type(self.local_graph)}'")
    # data_cls must be set before normalizing, since the helper validates
    # each partition book against it.
    self.node_pb = self._normalize_partition_book(node_pb, 'node')
    self.edge_pb = self._normalize_partition_book(edge_pb, 'edge')

  def _normalize_partition_book(self, pb, pb_kind: str):
    r"""Validate a node/edge partition book against ``self.data_cls``.

    Raw tensors are wrapped in ``GLTPartitionBook``.

    Args:
      pb: ``None``, a dict of partition books (hetero), a ``PartitionBook``,
        or a ``torch.Tensor`` (homo).
      pb_kind: ``'node'`` or ``'edge'``; only used in error messages.

    Returns:
      ``None`` if ``pb`` is ``None``. Otherwise the normalized partition
      book. A dict input is returned as the same dict object.

    Raises:
      ValueError: If ``pb`` has an unsupported type, or its form does not
        match ``self.data_cls``.
    """
    if pb is None:
      return None
    if isinstance(pb, dict):
      expected_cls = 'hetero'
    elif isinstance(pb, (PartitionBook, torch.Tensor)):
      expected_cls = 'homo'
    else:
      raise ValueError(f"'{self.__class__.__name__}': found invalid input "
                       f"{pb_kind} partition book type '{type(pb)}'")
    # Explicit check instead of `assert`, so the validation is not stripped
    # when Python runs with -O.
    if self.data_cls != expected_cls:
      raise ValueError(f"'{self.__class__.__name__}': {pb_kind} partition "
                       f"book type '{type(pb)}' does not match "
                       f"'{self.data_cls}' graph data")
    if expected_cls == 'hetero':
      # NOTE: converted in place, so the caller's dict is modified too.
      # Only values are reassigned (no keys added or removed), which is
      # safe during iteration.
      for key, book in pb.items():
        if not isinstance(book, PartitionBook):
          pb[key] = GLTPartitionBook(book)
      return pb
    # PartitionBook is checked before torch.Tensor, so an object that is
    # both is kept as-is rather than being wrapped again.
    if isinstance(pb, PartitionBook):
      return pb
    return GLTPartitionBook(pb)