in graphlearn_torch/python/distributed/dist_graph.py [0:0]
def __init__(self,
             num_partitions: int,
             partition_idx: int,
             local_graph: Union[Graph, Dict[EdgeType, Graph]],
             node_pb: Union[PartitionBook, HeteroNodePartitionDict],
             edge_pb: Union[PartitionBook, HeteroEdgePartitionDict]=None):
    """Wrap a local graph partition together with its partition books.

    Args:
        num_partitions: Total number of partitions; stored as-is.
        partition_idx: Index of the partition held here; stored as-is.
        local_graph: Either a single ``Graph`` (homogeneous data) or a dict
            mapping edge types to ``Graph`` (heterogeneous data).
            ``lazy_init()`` is called on every graph it contains.
        node_pb: Node partition book. Accepted forms are ``None``, a
            ``PartitionBook``, a ``torch.Tensor`` (homogeneous only; wrapped
            in ``GLTPartitionBook``), or a dict (heterogeneous only) whose
            values are wrapped in ``GLTPartitionBook`` unless they already
            are ``PartitionBook`` instances.
        edge_pb: Edge partition book; same accepted forms as ``node_pb``.

    Raises:
        ValueError: if ``local_graph`` is neither a dict nor a ``Graph``, if a
            partition book has an unsupported type, or if a partition book's
            form (dict vs. single book) does not match ``local_graph``.
    """
    self.num_partitions = num_partitions
    self.partition_idx = partition_idx
    self.local_graph = local_graph
    if isinstance(self.local_graph, dict):
        self.data_cls = 'hetero'
        for graph in self.local_graph.values():
            graph.lazy_init()
    elif isinstance(self.local_graph, Graph):
        self.data_cls = 'homo'
        self.local_graph.lazy_init()
    else:
        raise ValueError(f"'{self.__class__.__name__}': found invalid input "
                         f"graph type '{type(self.local_graph)}'")
    self.node_pb = self._normalize_partition_book(node_pb, 'node')
    self.edge_pb = self._normalize_partition_book(edge_pb, 'edge')

def _normalize_partition_book(self, pb, kind):
    """Validate a node/edge partition book and wrap raw tensors.

    Must be called after ``self.data_cls`` has been set.

    Args:
        pb: The partition book as passed to ``__init__`` (may be ``None``).
        kind: ``'node'`` or ``'edge'``; only used in error messages.

    Returns:
        ``None`` if ``pb`` is ``None``. Otherwise the same dict (with values
        wrapped in place), the same ``PartitionBook``, or a new
        ``GLTPartitionBook`` built from a tensor.

    Raises:
        ValueError: if ``pb`` has an unsupported type, or if its form does
            not match ``self.data_cls``.
    """
    if pb is None:
        return None
    if isinstance(pb, dict):
        expected_cls = 'hetero'
    elif isinstance(pb, (PartitionBook, torch.Tensor)):
        expected_cls = 'homo'
    else:
        raise ValueError(f"'{self.__class__.__name__}': found invalid input "
                         f"{kind} partition book type '{type(pb)}'")
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if self.data_cls != expected_cls:
        raise ValueError(f"'{self.__class__.__name__}': {kind} partition book "
                         f"of type '{type(pb)}' requires {expected_cls} graph "
                         f"data, but the local graph is {self.data_cls}")
    if isinstance(pb, dict):
        # Values are replaced in place (only existing keys are reassigned, so
        # iterating while assigning is safe); the caller's dict is therefore
        # updated as well, matching the previous behavior.
        for key, book in pb.items():
            if not isinstance(book, PartitionBook):
                pb[key] = GLTPartitionBook(book)
        return pb
    # Check PartitionBook before wrapping: an existing book is kept as-is even
    # if it is also a tensor.
    if isinstance(pb, PartitionBook):
        return pb
    return GLTPartitionBook(pb)