# graphlearn_torch/python/distributed/dist_random_partitioner.py
# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import threading
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from ..partition import (
save_meta, save_node_pb, save_edge_pb,
save_graph_partition, save_feature_partition,
PartitionBook
)
from ..typing import (
NodeType, EdgeType, TensorDataType,
GraphPartitionData, FeaturePartitionData
)
from ..utils import convert_to_tensor, ensure_dir, index_select
from .dist_context import get_context, init_worker_group
from .rpc import (
init_rpc, rpc_is_initialized, all_gather, barrier,
get_rpc_current_group_worker_names,
rpc_request_async, rpc_register, RpcCalleeBase
)


class RpcUpdatePartitionValueCallee(RpcCalleeBase):
def __init__(self, dist_partition_mgr):
super().__init__()
self.dist_partition_mgr = dist_partition_mgr
def call(self, *args, **kwargs):
self.dist_partition_mgr._update_part_val(*args, **kwargs)
return None


class RpcUpdatePartitionBookCallee(RpcCalleeBase):
def __init__(self, dist_partition_mgr):
super().__init__()
self.dist_partition_mgr = dist_partition_mgr
def call(self, *args, **kwargs):
self.dist_partition_mgr._update_pb(*args, **kwargs)
return None


class DistPartitionManager(object):
  r""" A state manager for distributed partitioning.

  It collects the partitioned values assigned to the current partition (either
  locally or via RPC from other partitioners) and, optionally, maintains the
  global partition book that maps every id to its partition index.
  """
def __init__(self, total_val_size: int = 1, generate_pb: bool = True):
assert rpc_is_initialized() is True
self.num_parts = get_context().world_size
self.cur_pidx = get_context().rank
self._lock = threading.RLock()
self._worker_names = get_rpc_current_group_worker_names()
self.reset(total_val_size, generate_pb)
val_update_callee = RpcUpdatePartitionValueCallee(self)
self._val_update_callee_id = rpc_register(val_update_callee)
pb_update_callee = RpcUpdatePartitionBookCallee(self)
self._pb_update_callee_id = rpc_register(pb_update_callee)
def reset(self, total_val_size: int, generate_pb: bool = True):
with self._lock:
self.generate_pb = generate_pb
self.cur_part_val_list = []
if self.generate_pb:
self.partition_book = torch.zeros(total_val_size, dtype=torch.int64)
else:
self.partition_book = None
def process(self, res_list: List[Tuple[Any, torch.Tensor]]):
r""" Process partitioned results of the current corresponded distributed
partitioner and synchronize with others.
Args:
res_list: The result list of value and ids for each partition.
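    For example, with two partitions, a valid ``res_list`` could look like
    (the tensor contents below are purely illustrative)::

      [
        (torch.tensor([1.0, 2.0]), torch.tensor([10, 42])),  # for partition 0
        (torch.tensor([3.0]), torch.tensor([7])),            # for partition 1
      ]

    Entry ``pidx`` holds the values and the global ids that should be owned
    by partition ``pidx``.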
"""
assert len(res_list) == self.num_parts
futs = []
for pidx, (val, val_idx) in enumerate(res_list):
if pidx == self.cur_pidx:
self._update_part_val(val, pidx)
else:
futs.append(rpc_request_async(self._worker_names[pidx],
self._val_update_callee_id,
args=(val, pidx)))
if self.generate_pb:
futs.extend(self._broadcast_pb(val_idx, pidx))
_ = torch.futures.wait_all(futs)
def _broadcast_pb(self, val_idx: torch.Tensor, target_pidx: int):
pb_update_futs = []
for pidx in range(self.num_parts):
if pidx == self.cur_pidx:
self._update_pb(val_idx, target_pidx)
else:
pb_update_futs.append(rpc_request_async(self._worker_names[pidx],
self._pb_update_callee_id,
args=(val_idx, target_pidx)))
return pb_update_futs
def _update_part_val(self, val, target_pidx: int):
assert target_pidx == self.cur_pidx
with self._lock:
if val is not None:
self.cur_part_val_list.append(val)
def _update_pb(self, val_idx: torch.Tensor, target_pidx: int):
with self._lock:
self.partition_book[val_idx] = target_pidx


class DistRandomPartitioner(object):
  r""" A distributed random partitioner for parallel partitioning of
  large-scale graphs and features.
  Each distributed partitioner processes a part of the full graph and feature
  data and partitions it in parallel. A distributed partitioner's rank
  corresponds to a partition index, and the number of distributed partitioners
  must equal the number of output partitions. During partitioning, the
  partitioned results are sent to the other distributed partitioners according
  to their ranks. Afterwards, each distributed partitioner owns the partition
  that matches its rank and saves the partitioned results into the local
  output directory.
Args:
output_dir: The output root directory on local machine for partitioned
results.
num_nodes: Number of all graph nodes, should be a dict for hetero data.
edge_index: A part of the edge index data of graph edges, should be a dict
for hetero data.
edge_ids: The edge ids of the input ``edge_index``.
node_feat: A part of the node feature data, should be a dict for hetero data.
node_feat_ids: The node ids corresponding to the input ``node_feat``.
    edge_feat: A part of the edge feature data, should be a dict for hetero data.
edge_feat_ids: The edge ids corresponding to the input ``edge_feat``.
num_parts: The number of all partitions. If not provided, the value of
``graphlearn_torch.distributed.get_context().world_size`` will be used.
current_partition_idx: The partition index corresponding to the current
distributed partitioner. If not provided, the value of
``graphlearn_torch.distributed.get_context().rank`` will be used.
node_feat_dtype: The data type of node features.
edge_feat_dtype: The data type of edge features.
edge_assign_strategy: The assignment strategy when partitioning edges,
should be 'by_src' or 'by_dst'.
    chunk_size: The number of values to process per chunk during partitioning.
    master_addr: The master TCP address for RPC connection between all
      distributed partitioners. If not provided, the ``MASTER_ADDR``
      environment variable will be used.
    master_port: The master TCP port for RPC connection between all
      distributed partitioners. If not provided, the ``MASTER_PORT``
      environment variable will be used.
num_rpc_threads: The number of RPC worker threads to use.
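  Example:
    A minimal two-process sketch (the output path, the local tensors and the
    ``rank`` variable below are illustrative assumptions, not values defined
    by this module). Each process holds a slice of the edges and features and
    runs::

      partitioner = DistRandomPartitioner(
        output_dir='/tmp/random_parts',     # hypothetical output root
        num_nodes=10000,
        edge_index=local_edge_index,        # [2, num_local_edges] int64 tensor
        edge_ids=local_edge_ids,            # global ids of the local edges
        node_feat=local_node_feat,
        node_feat_ids=local_node_feat_ids,
        num_parts=2,
        current_partition_idx=rank,         # 0 on one process, 1 on the other
        master_addr='localhost',
        master_port=11234,
      )
      partitioner.partition()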
"""
def __init__(
self,
output_dir: str,
num_nodes: Union[int, Dict[NodeType, int]],
edge_index: Union[TensorDataType, Dict[EdgeType, TensorDataType]],
edge_ids: Union[TensorDataType, Dict[EdgeType, TensorDataType]],
node_feat: Optional[Union[TensorDataType, Dict[NodeType, TensorDataType]]] = None,
node_feat_ids: Optional[Union[TensorDataType, Dict[NodeType, TensorDataType]]] = None,
edge_feat: Optional[Union[TensorDataType, Dict[EdgeType, TensorDataType]]] = None,
edge_feat_ids: Optional[Union[TensorDataType, Dict[EdgeType, TensorDataType]]] = None,
num_parts: Optional[int] = None,
current_partition_idx: Optional[int] = None,
node_feat_dtype: torch.dtype = torch.float32,
edge_feat_dtype: torch.dtype = torch.float32,
edge_assign_strategy: str = 'by_src',
chunk_size: int = 10000,
master_addr: Optional[str] = None,
    master_port: Optional[int] = None,
num_rpc_threads: int = 16,
):
self.output_dir = output_dir
if get_context() is not None:
if num_parts is not None:
assert get_context().world_size == num_parts
if current_partition_idx is not None:
assert get_context().rank == current_partition_idx
else:
assert num_parts is not None
assert current_partition_idx is not None
init_worker_group(
world_size=num_parts,
rank=current_partition_idx,
group_name='distributed_random_partitoner'
)
self.num_parts = get_context().world_size
self.current_partition_idx = get_context().rank
if rpc_is_initialized() is not True:
if master_addr is None:
master_addr = os.environ['MASTER_ADDR']
if master_port is None:
master_port = int(os.environ['MASTER_PORT'])
init_rpc(master_addr, master_port, num_rpc_threads)
self.num_nodes = num_nodes
self.edge_index = convert_to_tensor(edge_index, dtype=torch.int64)
self.edge_ids = convert_to_tensor(edge_ids, dtype=torch.int64)
self.node_feat = convert_to_tensor(node_feat, dtype=node_feat_dtype)
self.node_feat_ids = convert_to_tensor(node_feat_ids, dtype=torch.int64)
if self.node_feat is not None:
assert self.node_feat_ids is not None
self.edge_feat = convert_to_tensor(edge_feat, dtype=edge_feat_dtype)
self.edge_feat_ids = convert_to_tensor(edge_feat_ids, dtype=torch.int64)
if self.edge_feat is not None:
assert self.edge_feat_ids is not None
if isinstance(self.num_nodes, dict):
assert isinstance(self.edge_index, dict)
assert isinstance(self.edge_ids, dict)
assert isinstance(self.node_feat, dict) or self.node_feat is None
assert isinstance(self.node_feat_ids, dict) or self.node_feat_ids is None
assert isinstance(self.edge_feat, dict) or self.edge_feat is None
assert isinstance(self.edge_feat_ids, dict) or self.edge_feat_ids is None
self.data_cls = 'hetero'
self.node_types = sorted(list(self.num_nodes.keys()))
self.edge_types = sorted(list(self.edge_index.keys()))
self.num_local_edges = {}
self.num_edges = {}
for etype, index in self.edge_index.items():
self.num_local_edges[etype] = len(index[0])
self.num_edges[etype] = sum(all_gather(len(index[0])).values())
else:
self.data_cls = 'homo'
self.node_types = None
self.edge_types = None
self.num_local_edges = len(self.edge_index[0])
self.num_edges = sum(all_gather(len(self.edge_index[0])).values())
self.edge_assign_strategy = edge_assign_strategy.lower()
assert self.edge_assign_strategy in ['by_src', 'by_dst']
self.chunk_size = chunk_size
self._partition_mgr = DistPartitionManager()
def _partition_by_chunk(
self,
val: Any,
val_idx: torch.Tensor,
partition_fn,
total_val_size: int,
generate_pb = True
):
r""" Partition generic values and sync with all other partitoners.
"""
val_num = len(val_idx)
chunk_num = (val_num + self.chunk_size - 1) // self.chunk_size
chunk_start_pos = 0
self._partition_mgr.reset(total_val_size, generate_pb)
barrier()
for _ in range(chunk_num):
chunk_end_pos = min(val_num, chunk_start_pos + self.chunk_size)
current_chunk_size = chunk_end_pos - chunk_start_pos
chunk_idx = torch.arange(current_chunk_size, dtype=torch.long)
chunk_val = index_select(val, index=(chunk_start_pos, chunk_end_pos))
chunk_val_idx = val_idx[chunk_start_pos:chunk_end_pos]
chunk_partition_idx = partition_fn(
chunk_val_idx, (chunk_start_pos, chunk_end_pos))
chunk_res = []
for pidx in range(self.num_parts):
mask = (chunk_partition_idx == pidx)
idx = torch.masked_select(chunk_idx, mask)
chunk_res.append((index_select(chunk_val, idx), chunk_val_idx[idx]))
self._partition_mgr.process(chunk_res)
chunk_start_pos += current_chunk_size
barrier()
return (
self._partition_mgr.cur_part_val_list,
self._partition_mgr.partition_book
)
def _partition_node(
self,
ntype: Optional[NodeType] = None
) -> PartitionBook:
r""" Partition graph nodes of a specify node type in parallel.
Args:
ntype (str): The type for input nodes, must be provided for heterogeneous
graph. (default: ``None``)
Returns:
PartitionBook: The partition book of graph nodes.
"""
if 'hetero' == self.data_cls:
assert ntype is not None
node_num = self.num_nodes[ntype]
else:
node_num = self.num_nodes
per_node_num = node_num // self.num_parts
local_node_start = per_node_num * self.current_partition_idx
local_node_end = min(
node_num,
per_node_num * (self.current_partition_idx + 1)
)
local_node_ids = torch.arange(
local_node_start, local_node_end, dtype=torch.int64
)
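    # Assignment strategy: ``id % num_parts`` yields a balanced round-robin
    # mapping, which is then randomly permuted within each chunk so that the
    # partition sizes stay even while the node-to-partition mapping is random.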
def _node_pfn(n_ids, _):
partition_idx = n_ids % self.num_parts
rand_order = torch.randperm(len(n_ids))
return partition_idx[rand_order]
_, node_pb = self._partition_by_chunk(
val=None,
val_idx=local_node_ids,
partition_fn=_node_pfn,
total_val_size=node_num,
generate_pb=True
)
return node_pb
def _partition_graph(
self,
node_pbs: Union[PartitionBook, Dict[NodeType, PartitionBook]],
etype: Optional[EdgeType] = None
) -> Tuple[GraphPartitionData, PartitionBook]:
r""" Partition graph topology of a specify edge type in parallel.
Args:
node_pbs (PartitionBook or Dict[NodeType, PartitionBook]): The
partition books of all graph nodes.
etype (Tuple[str, str, str]): The type for input edges, must be provided
for heterogeneous graph. (default: ``None``)
Returns:
GraphPartitionData: The graph data of the current partition.
PartitionBook: The partition book of graph edges.
"""
if 'hetero' == self.data_cls:
assert isinstance(node_pbs, dict)
assert etype is not None
src_ntype, _, dst_ntype = etype
edge_index = self.edge_index[etype]
rows, cols = edge_index[0], edge_index[1]
eids = self.edge_ids[etype]
edge_num = self.num_edges[etype]
if 'by_src' == self.edge_assign_strategy:
target_node_pb = node_pbs[src_ntype]
target_indices = rows
else:
target_node_pb = node_pbs[dst_ntype]
target_indices = cols
else:
edge_index = self.edge_index
rows, cols = edge_index[0], edge_index[1]
eids = self.edge_ids
edge_num = self.num_edges
target_node_pb = node_pbs
target_indices = rows if 'by_src' == self.edge_assign_strategy else cols
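    # Each edge follows its target node (source or destination, depending on
    # ``edge_assign_strategy``): its partition is looked up from that node's
    # entry in the node partition book.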
def _edge_pfn(_, chunk_range):
chunk_target_indices = index_select(target_indices, chunk_range)
return target_node_pb[chunk_target_indices]
res_list, edge_pb = self._partition_by_chunk(
val=(rows, cols, eids),
val_idx=eids,
partition_fn=_edge_pfn,
total_val_size=edge_num,
generate_pb=True
)
current_graph_part = GraphPartitionData(
edge_index=(
torch.cat([r[0] for r in res_list]),
torch.cat([r[1] for r in res_list]),
),
eids=torch.cat([r[2] for r in res_list])
)
return current_graph_part, edge_pb
def _partition_node_feat(
self,
node_pb: PartitionBook,
ntype: Optional[NodeType] = None,
) -> Optional[FeaturePartitionData]:
r""" Split node features in parallel by the partitioned node results,
and return the current partition of node features.
"""
if self.node_feat is None:
return None
if 'hetero' == self.data_cls:
assert ntype is not None
node_num = self.num_nodes[ntype]
node_feat = self.node_feat[ntype]
node_feat_ids = self.node_feat_ids[ntype]
else:
node_num = self.num_nodes
node_feat = self.node_feat
node_feat_ids = self.node_feat_ids
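    # Node features are co-located with their nodes: each feature row goes to
    # the partition recorded for its node id in the node partition book.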
def _node_feat_pfn(nfeat_ids, _):
return node_pb[nfeat_ids]
res_list, _ = self._partition_by_chunk(
val=(node_feat, node_feat_ids),
val_idx=node_feat_ids,
partition_fn=_node_feat_pfn,
total_val_size=node_num,
generate_pb=False
)
return FeaturePartitionData(
feats=torch.cat([r[0] for r in res_list]),
ids=torch.cat([r[1] for r in res_list]),
cache_feats=None,
cache_ids=None
)
def _partition_edge_feat(
self,
edge_pb: PartitionBook,
etype: Optional[EdgeType] = None,
) -> Optional[FeaturePartitionData]:
r""" Split edge features in parallel by the partitioned edge results,
and return the current partition of edge features.
"""
if self.edge_feat is None:
return None
if 'hetero' == self.data_cls:
assert etype is not None
edge_num = self.num_edges[etype]
edge_feat = self.edge_feat[etype]
edge_feat_ids = self.edge_feat_ids[etype]
else:
edge_num = self.num_edges
edge_feat = self.edge_feat
edge_feat_ids = self.edge_feat_ids
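    # Edge features are co-located with their edges: each feature row goes to
    # the partition recorded for its edge id in the edge partition book.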
def _edge_feat_pfn(efeat_ids, _):
return edge_pb[efeat_ids]
res_list, _ = self._partition_by_chunk(
val=(edge_feat, edge_feat_ids),
val_idx=edge_feat_ids,
partition_fn=_edge_feat_pfn,
total_val_size=edge_num,
generate_pb=False
)
return FeaturePartitionData(
feats=torch.cat([r[0] for r in res_list]),
ids=torch.cat([r[1] for r in res_list]),
cache_feats=None,
cache_ids=None
)
def partition(self):
r""" Partition graph and feature data into different parts along with all
other distributed partitioners, save the result of the current partition
index into output directory.
"""
ensure_dir(self.output_dir)
if 'hetero' == self.data_cls:
node_pb_dict = {}
for ntype in self.node_types:
node_pb = self._partition_node(ntype)
node_pb_dict[ntype] = node_pb
save_node_pb(self.output_dir, node_pb, ntype)
current_node_feat_part = self._partition_node_feat(node_pb, ntype)
if current_node_feat_part is not None:
save_feature_partition(
self.output_dir, self.current_partition_idx, current_node_feat_part,
group='node_feat', graph_type=ntype
)
del current_node_feat_part
for etype in self.edge_types:
current_graph_part, edge_pb = self._partition_graph(node_pb_dict, etype)
save_edge_pb(self.output_dir, edge_pb, etype)
save_graph_partition(
self.output_dir, self.current_partition_idx, current_graph_part, etype
)
del current_graph_part
current_edge_feat_part = self._partition_edge_feat(edge_pb, etype)
if current_edge_feat_part is not None:
save_feature_partition(
self.output_dir, self.current_partition_idx, current_edge_feat_part,
group='edge_feat', graph_type=etype
)
del current_edge_feat_part
else:
node_pb = self._partition_node()
save_node_pb(self.output_dir, node_pb)
current_node_feat_part = self._partition_node_feat(node_pb)
if current_node_feat_part is not None:
save_feature_partition(
self.output_dir, self.current_partition_idx,
current_node_feat_part, group='node_feat'
)
del current_node_feat_part
current_graph_part, edge_pb = self._partition_graph(node_pb)
save_edge_pb(self.output_dir, edge_pb)
save_graph_partition(
self.output_dir, self.current_partition_idx, current_graph_part
)
del current_graph_part
current_edge_feat_part = self._partition_edge_feat(edge_pb)
if current_edge_feat_part is not None:
save_feature_partition(
self.output_dir, self.current_partition_idx,
current_edge_feat_part, group='edge_feat'
)
del current_edge_feat_part
# save meta.
save_meta(self.output_dir, self.num_parts, self.data_cls,
self.node_types, self.edge_types)