in graphlearn_torch/python/distributed/dist_random_partitioner.py [0:0]
def partition(self):
  r""" Run the node and edge partitioning for this partitioner and write the
  results into ``self.output_dir``.

  The output directory is created first. For heterogeneous data
  (``self.data_cls == 'hetero'``) every type in ``self.node_types`` is
  processed in a first pass, and the per-node-type results are collected into
  a dict. A second pass then handles every type in ``self.edge_types`` and is
  handed that dict. For any other ``data_cls`` a single untyped node pass is
  followed by a single untyped edge pass.

  For each pass, the node/edge ``pb`` results are saved. The graph partition
  and any non-``None`` feature partitions are saved under
  ``self.current_partition_idx``. The meta information is written last.
  """
  ensure_dir(self.output_dir)

  def _feat_kwargs(type_args):
    # Only typed (heterogeneous) passes forward `graph_type=`; untyped passes
    # omit the keyword entirely.
    return {'graph_type': type_args[0]} if type_args else {}

  def _run_node_pass(*type_args):
    # `type_args` is `()` for homogeneous data or `(ntype,)` for
    # heterogeneous data; it is forwarded positionally to every call.
    pb = self._partition_node(*type_args)
    save_node_pb(self.output_dir, pb, *type_args)
    feat_part = self._partition_node_feat(pb, *type_args)
    if feat_part is not None:
      save_feature_partition(
        self.output_dir, self.current_partition_idx, feat_part,
        group='node_feat', **_feat_kwargs(type_args)
      )
    return pb

  def _run_edge_pass(node_pbs, *type_args):
    # `node_pbs` is the single node result (homogeneous) or the dict of
    # per-node-type results (heterogeneous).
    graph_part, edge_pb = self._partition_graph(node_pbs, *type_args)
    save_edge_pb(self.output_dir, edge_pb, *type_args)
    save_graph_partition(
      self.output_dir, self.current_partition_idx, graph_part, *type_args
    )
    # Drop this reference before partitioning edge features so the graph
    # partition can be freed first (presumably to limit peak memory).
    del graph_part
    feat_part = self._partition_edge_feat(edge_pb, *type_args)
    if feat_part is not None:
      save_feature_partition(
        self.output_dir, self.current_partition_idx, feat_part,
        group='edge_feat', **_feat_kwargs(type_args)
      )

  if self.data_cls == 'hetero':
    pbs_by_ntype = {}
    for node_type in self.node_types:
      pbs_by_ntype[node_type] = _run_node_pass(node_type)
    for edge_type in self.edge_types:
      _run_edge_pass(pbs_by_ntype, edge_type)
  else:
    _run_edge_pass(_run_node_pass())

  # save meta.
  save_meta(self.output_dir, self.num_parts, self.data_cls,
            self.node_types, self.edge_types)