graphlearn_torch/python/distributed/dist_feature.py (352 lines of code) (raw):

# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist

from ..data import Feature
from ..typing import (
  EdgeType,
  NodeType,
)
from ..sampler import (
  SamplerOutput,
  HeteroSamplerOutput,
)
from ..partition import (
  PartitionBook,
  GLTPartitionBook,
  HeteroNodePartitionDict,
  HeteroEdgePartitionDict
)
from ..utils import get_available_device, ensure_device

from .rpc import (
  RpcDataPartitionRouter,
  RpcCalleeBase,
  rpc_register,
  rpc_request_async
)


# Given a set of node ids, the `PartialFeature` stores the feature info
# of a subset of the original ids, the first tensor is the features of the
# subset node ids, and the second tensor records the index of the subset
# node ids.
PartialFeature = Tuple[torch.Tensor, torch.Tensor]


def communicate_node_num(send_tensor):
  r""" Exchange one integer with every rank via ``all_to_all_single``.

  Element ``i`` of ``send_tensor`` is sent to rank ``i``; element ``i`` of
  the returned receive tensor is the value rank ``i`` sent to this rank.
  A ``dist.barrier()`` is issued after the exchange completes.

  Args:
    send_tensor: a 1-D tensor, or a sequence of ints, with one entry per rank
      (assumes the default process group has ``len(send_tensor)`` ranks).

  Returns:
    ``(send_tensor, recv_tensor)``. A non-tensor input is converted to an
    int64 tensor; ``recv_tensor`` is a CPU tensor with the same dtype.
  """
  if not torch.is_tensor(send_tensor):
    send_tensor = torch.tensor(send_tensor, dtype=torch.int64)
    recv_tensor = torch.zeros(send_tensor.shape[0], dtype=torch.int64)
  else:
    recv_tensor = torch.zeros(send_tensor.shape[0], dtype=send_tensor.dtype)
  # Exactly one element is sent to, and received from, every rank.
  scount = [1] * send_tensor.shape[0]
  rcount = [1] * send_tensor.shape[0]
  sync_req = dist.all_to_all_single(recv_tensor, send_tensor, rcount, scount,
                                    async_op=True)
  sync_req.wait()
  dist.barrier()
  return send_tensor, recv_tensor


class RpcFeatureLookupCallee(RpcCalleeBase):
  r""" A wrapper for rpc callee that will perform feature lookup from
  remote processes.
  """
  def __init__(self, dist_feature):
    super().__init__()
    self.dist_feature = dist_feature

  def call(self, *args, **kwargs):
    # Forwarded to ``DistFeature.local_get`` on the process owning the data.
    return self.dist_feature.local_get(*args, **kwargs)


class DistFeature(object):
  r""" Distributed feature data manager for global feature lookups.

  Args:
    num_partitions: Number of data partitions.
    partition_idx: Data partition idx of current process.
    local_feature: Local ``Feature`` instance.
    feature_pb: Partition book which records node/edge ids to worker node
      ids mapping on feature store.
    local_only: Use this instance only for local feature lookup or stitching.
      If set to ``True``, the related rpc callee will not be registered and
      users should ensure that lookups for remote features are not invoked
      through this instance. Default to ``False``.
    rpc_router: Router used to resolve the rpc worker of a partition.
      Required when ``local_only`` is ``False``.
    device: Device used for computing. Default to ``None``.

  Note that `local_feature` and `feature_pb` should be a dictionary
  for hetero data.
  """
  def __init__(self,
               num_partitions: int,
               partition_idx: int,
               local_feature: Union[Feature,
                                    Dict[Union[NodeType, EdgeType], Feature]],
               feature_pb: Union[PartitionBook,
                                 HeteroNodePartitionDict,
                                 HeteroEdgePartitionDict],
               local_only: bool = False,
               rpc_router: Optional[RpcDataPartitionRouter] = None,
               device: Optional[torch.device] = None):
    self.num_partitions = num_partitions
    self.partition_idx = partition_idx
    self.device = get_available_device(device)
    ensure_device(self.device)

    self.local_feature = local_feature
    if isinstance(self.local_feature, dict):
      self.data_cls = 'hetero'
      for _, feat in self.local_feature.items():
        if isinstance(feat, Feature):
          feat.lazy_init_with_ipc_handle()
    elif isinstance(self.local_feature, Feature):
      self.data_cls = 'homo'
      self.local_feature.lazy_init_with_ipc_handle()
    else:
      raise ValueError(f"'{self.__class__.__name__}': found invalid input "
                       f"feature type '{type(self.local_feature)}'")

    self.feature_pb = feature_pb
    if isinstance(self.feature_pb, dict):
      assert self.data_cls == 'hetero'
      # NOTE: raw entries are wrapped in place, i.e. the caller's dict is
      # updated as well.
      for key, feat in self.feature_pb.items():
        if not isinstance(feat, PartitionBook):
          self.feature_pb[key] = GLTPartitionBook(feat)
    elif isinstance(self.feature_pb, PartitionBook):
      assert self.data_cls == 'homo'
    elif isinstance(self.feature_pb, torch.Tensor):
      assert self.data_cls == 'homo'
      self.feature_pb = GLTPartitionBook(self.feature_pb)
    else:
      raise ValueError(f"'{self.__class__.__name__}': found invalid input "
                       f"patition book type '{type(self.feature_pb)}'")

    self.rpc_router = rpc_router
    if not local_only:
      if self.rpc_router is None:
        raise ValueError(f"'{self.__class__.__name__}': a rpc router must be "
                         f"provided when `local_only` set to `False`")
      rpc_callee = RpcFeatureLookupCallee(self)
      self.rpc_callee_id = rpc_register(rpc_callee)
    else:
      self.rpc_callee_id = None

  def _get_local_store(self, input_type: Optional[Union[NodeType, EdgeType]]):
    r""" Return ``(local feature, partition book)`` for ``input_type``
    (hetero) or the single store (homo; ``input_type`` is ignored).
    """
    if self.data_cls == 'hetero':
      assert input_type is not None
      return self.local_feature[input_type], self.feature_pb[input_type]
    return self.local_feature, self.feature_pb

  def local_get(
    self,
    ids: torch.Tensor,
    input_type: Optional[Union[NodeType, EdgeType]] = None
  ) -> torch.Tensor:
    r""" Lookup features in the local feature store, the input node/edge ids
    should be guaranteed to be all local to the current feature store.
    """
    feat, _ = self._get_local_store(input_type)
    # TODO: check performance with `return feat[ids].cpu()`
    return feat.cpu_get(ids)

  def get_all2all(
    self,
    sampler_result: Union[SamplerOutput, HeteroSamplerOutput],
    ntype_list: List[NodeType]
  ) -> Dict[NodeType, torch.Tensor]:
    r""" Lookup features synchronously using torch.distributed.all_to_all.

    This runs collective operations, so every rank must call it. It iterates
    ``sampler_result.node.items()``, i.e. it expects a dict of node ids per
    type. Node types of ``sampler_result`` must appear in ``ntype_list``;
    otherwise no remote features are gathered for them and ``_stitch``
    receives ``None`` as the remote list.
    """
    remote_feats_dict = self.remote_selecting_get_all2all(sampler_result,
                                                          ntype_list)
    feat_dict = {}
    for ntype, nodes in sampler_result.node.items():
      nodes = nodes.to(torch.long)
      local_feat = self._local_selecting_get(nodes, ntype)
      remote_feats = remote_feats_dict.get(ntype, None)
      feat_dict[ntype] = self._stitch(nodes, local_feat, remote_feats)
    return feat_dict

  def async_get(
    self,
    ids: torch.Tensor,
    input_type: Optional[Union[NodeType, EdgeType]] = None
  ) -> torch.futures.Future:
    r""" Lookup features asynchronously and return a future.
    """
    # Issue the remote requests first so they overlap with the local lookup.
    remote_fut = self._remote_selecting_get(ids, input_type)
    local_feature = self._local_selecting_get(ids, input_type)
    res_fut = torch.futures.Future()
    def on_done(*_):
      try:
        remote_feature_list = remote_fut.wait()
        result = self._stitch(ids, local_feature, remote_feature_list)
      except Exception as e:
        res_fut.set_exception(e)
      else:
        res_fut.set_result(result)
    remote_fut.add_done_callback(on_done)
    return res_fut

  def __getitem__(
    self,
    input: Union[torch.Tensor, Tuple[Union[NodeType, EdgeType], torch.Tensor]]
  ) -> torch.Tensor:
    r""" Lookup features synchronously in a '__getitem__' way.

    Accepts either a tensor of ids (homo) or an ``(input_type, ids)`` tuple.
    """
    if isinstance(input, torch.Tensor):
      input_type, ids = None, input
    elif isinstance(input, tuple):
      # Fixed: previously read from the not-yet-bound name `ids`.
      input_type, ids = input[0], input[1]
    else:
      raise ValueError(f"'{self.__class__.__name__}': found invalid input "
                       f"type for feature lookup: '{type(input)}'")
    fut = self.async_get(ids, input_type)
    return fut.wait()

  def _local_selecting_get(
    self,
    ids: torch.Tensor,
    input_type: Optional[Union[NodeType, EdgeType]] = None
  ) -> PartialFeature:
    r""" Select node/edge ids only in the local feature store and lookup
    features of them.

    Args:
      ids: input node/edge ids.
      input_type: input node/edge type for heterogeneous feature lookup.

    Return:
      PartialFeature: features and index for local node/edge ids.
    """
    feat, pb = self._get_local_store(input_type)
    input_order = torch.arange(ids.size(0), dtype=torch.long,
                               device=self.device)
    partition_ids = pb[ids.to(pb.device)].to(self.device)
    ids = ids.to(self.device)
    local_mask = (partition_ids == self.partition_idx)
    local_ids = torch.masked_select(ids, local_mask)
    local_index = torch.masked_select(input_order, local_mask)
    return feat[local_ids], local_index

  def remote_selecting_prepare(
    self,
    sampler_result: Union[SamplerOutput, HeteroSamplerOutput],
    ntype_list: List[NodeType]
  ):
    r""" Exchange, per node type, how many ids each rank requests from every
    other rank.

    Returns:
      ``(send_dict, recv_dict)`` mapping each type in ``ntype_list`` to the
      per-rank counts this rank requests / is requested for.
    """
    rfeat_recv_dict = {}
    rfeat_send_dict = {}
    for ntype in ntype_list:
      ids = sampler_result.node.get(ntype, None)
      if ids is None:
        send_remote_count = torch.zeros(self.num_partitions, dtype=torch.int64)
      else:
        ids = ids.to(torch.long)
        _, pb = self._get_local_store(ntype)
        ids = ids.to(self.device)
        partition_ids = pb[ids.to(pb.device)].to(self.device)
        send_remote_count = []
        for pidx in range(0, self.num_partitions):
          if pidx == self.partition_idx:
            # Local ids are never requested remotely.
            send_remote_count.append(0)
          else:
            remote_mask = (partition_ids == pidx)
            remote_ids = torch.masked_select(ids, remote_mask)
            send_remote_count.append(remote_ids.numel())
      send_sr, recv_sr = communicate_node_num(send_remote_count)
      rfeat_recv_dict[ntype] = recv_sr
      rfeat_send_dict[ntype] = send_sr
    return rfeat_send_dict, rfeat_recv_dict

  def communicate_node_id(
    self,
    sampler_result: Union[SamplerOutput, HeteroSamplerOutput],
    ntype_list: List[NodeType]
  ):
    r""" Send the remote ids this rank needs to their owning ranks.

    Ids are packed rank by rank, and within a rank in ``ntype_list`` order.
    The ids received from other ranks are stored in ``self.recv_rn_gnid``
    and their per-rank counts in ``self.recv_rn_count``.

    Returns:
      ``(remote_cnt_list, indexes)``: the number of ids sent to each rank,
      and per type a per-rank list of positions of those ids in
      ``sampler_result.node[ntype]`` (``None`` for the local rank and for
      types absent from ``sampler_result``).
    """
    indexes = {ntype: [None] * self.num_partitions for ntype in ntype_list}
    remote_cnt_list = torch.zeros(self.num_partitions, dtype=torch.long)

    # Resolve the owning partition of every sampled node once per type;
    # this does not depend on the destination rank.
    located = []
    for ntype in ntype_list:
      nodes = sampler_result.node.get(ntype, None)
      if nodes is None:
        continue
      nodes = nodes.to(torch.long)
      _, pb = self._get_local_store(ntype)
      input_order = torch.arange(nodes.size(0), dtype=torch.long,
                                 device=self.device)
      partition_ids = pb[nodes.to(pb.device)].to(self.device)
      located.append((ntype, nodes.to(self.device), input_order,
                      partition_ids))

    send_chunks = []
    for pidx in range(self.num_partitions):
      if pidx == self.partition_idx:
        continue
      remote_cnt_sum = 0
      for ntype, nodes, input_order, partition_ids in located:
        remote_mask = (partition_ids == pidx)
        remote_ids = torch.masked_select(nodes, remote_mask)
        indexes[ntype][pidx] = torch.masked_select(input_order, remote_mask)
        send_chunks.append(remote_ids.cpu())
        remote_cnt_sum += remote_ids.numel()
      remote_cnt_list[pidx] = remote_cnt_sum

    # Always int64 on CPU, even when nothing is requested (an empty Python
    # list would otherwise become a float32 tensor).
    if send_chunks:
      send_ids = torch.cat(send_chunks)
    else:
      send_ids = torch.empty(0, dtype=torch.long)
    assert send_ids.numel() == int(remote_cnt_list.sum())

    _, recv_sr = communicate_node_num(remote_cnt_list)
    self.recv_rn_count = [int(cnt) for cnt in recv_sr]
    self.recv_rn_gnid = torch.zeros(sum(self.recv_rn_count), dtype=torch.long)
    dist.all_to_all_single(self.recv_rn_gnid, send_ids, self.recv_rn_count,
                           remote_cnt_list.tolist(), async_op=False)
    return remote_cnt_list, indexes

  def communicate_node_feats(
    self,
    ntype_list: List[NodeType],
    remote_cnt: torch.Tensor,
    send_num_dict: Dict[NodeType, List[int]],
    recv_num_dict: Dict[NodeType, List[int]],
    indexes: Dict[NodeType, List]
  ):
    r""" Serve the ids received in ``communicate_node_id`` and collect the
    features of this rank's own remote requests.

    Assumes every rank uses the same ``ntype_list`` order, since ids and
    features are packed and unpacked by walking it.

    Returns:
      Dict mapping each type to a list of ``PartialFeature`` pairs
      ``(features, positions)``, one per rank that returned features.
    """
    rfeats_list = []
    offset = 0
    for pidx in range(self.num_partitions):
      if pidx == self.partition_idx:
        continue
      for ntype in ntype_list:
        feat_num = int(recv_num_dict.get(ntype)[pidx])
        if feat_num > 0:
          feat, _ = self._get_local_store(ntype)
          ntype_ids = self.recv_rn_gnid[offset:offset + feat_num]
          offset = offset + feat_num
          rfeats_list.append(feat[ntype_ids])

    if rfeats_list:
      rfeats_send = torch.cat(rfeats_list, dim=0)
    else:
      # No rank requested anything from us; `torch.cat([])` would raise.
      # An empty lookup still carries the feature width and dtype needed to
      # size the receive buffer (assumes all types share one width, as the
      # concatenation above already requires).
      feat, _ = self._get_local_store(ntype_list[0])
      rfeats_send = feat[torch.empty(0, dtype=torch.long)]
    feat_size = rfeats_send.shape[1]
    send_count = self.recv_rn_count
    recv_count = remote_cnt.tolist()
    recv_feats = torch.zeros((sum(recv_count), feat_size),
                             dtype=rfeats_send.dtype)
    req = dist.all_to_all_single(recv_feats, rfeats_send,
                                 recv_count, send_count, async_op=True)
    req.wait()
    dist.barrier()

    recv_feat_list = torch.split(recv_feats, recv_count, dim=0)
    remote_feats_dict = {ntype: [] for ntype in ntype_list}
    for pidx in range(self.num_partitions):
      if pidx == self.partition_idx:
        continue
      offset = 0
      for ntype in ntype_list:
        send_num = int(send_num_dict.get(ntype)[pidx])
        if send_num > 0:
          ntype_feat = recv_feat_list[pidx][offset:offset + send_num, :]
          remote_feats_dict[ntype].append((ntype_feat, indexes[ntype][pidx]))
          offset = offset + send_num
    return remote_feats_dict

  def remote_selecting_get_all2all(
    self,
    sampler_result: Union[SamplerOutput, HeteroSamplerOutput],
    ntype_list: List[NodeType]
  ) -> Dict[NodeType, List]:
    r""" Gather remote features of ``sampler_result`` via all-to-all
    collectives: exchange counts, then ids, then features.
    """
    rfeat_send_dict, rfeat_recv_dict = self.remote_selecting_prepare(
      sampler_result, ntype_list)
    remote_cnt, indexes = self.communicate_node_id(sampler_result, ntype_list)
    dist.barrier()
    remote_feats_dict = self.communicate_node_feats(
      ntype_list, remote_cnt, rfeat_send_dict, rfeat_recv_dict, indexes)
    return remote_feats_dict

  def _remote_selecting_get(
    self,
    ids: torch.Tensor,
    input_type: Optional[Union[NodeType, EdgeType]] = None
  ) -> torch.futures.Future:
    r""" Select node/edge ids only in the remote feature stores and fetch
    their features.

    Args:
      ids: input node/edge ids.
      input_type: input node/edge type for heterogeneous feature lookup.

    Return:
      torch.futures.Future: a torch future with a list of `PartialFeature`,
        which corresponds to partial features on different remote workers.
    """
    assert (
      self.rpc_callee_id is not None
    ), "Remote feature lookup is disabled in 'local_only' mode."

    _, pb = self._get_local_store(input_type)
    ids = ids.to(pb.device)
    input_order = torch.arange(ids.size(0), dtype=torch.long)
    partition_ids = pb[ids].cpu()
    # The mask lives on CPU, so select from a CPU copy of the ids;
    # `masked_select` requires both tensors on the same device.
    cpu_ids = ids.cpu()
    futs, indexes = [], []
    for pidx in range(0, self.num_partitions):
      if pidx == self.partition_idx:
        continue
      remote_mask = (partition_ids == pidx)
      remote_ids = torch.masked_select(cpu_ids, remote_mask)
      if remote_ids.shape[0] > 0:
        to_worker = self.rpc_router.get_to_worker(pidx)
        futs.append(rpc_request_async(to_worker,
                                      self.rpc_callee_id,
                                      args=(remote_ids, input_type)))
        indexes.append(torch.masked_select(input_order, remote_mask))
    collect_fut = torch.futures.collect_all(futs)
    res_fut = torch.futures.Future()
    def on_done(*_):
      try:
        fut_list = collect_fut.wait()
        result = []
        for i, fut in enumerate(fut_list):
          result.append((fut.wait(), indexes[i]))
      except Exception as e:
        res_fut.set_exception(e)
      else:
        res_fut.set_result(result)
    collect_fut.add_done_callback(on_done)
    return res_fut

  def _stitch(
    self,
    ids: torch.Tensor,
    local: PartialFeature,
    remotes: List[PartialFeature]
  ) -> torch.Tensor:
    r""" Stitch local and remote partial features into a complete one.

    Args:
      ids: the complete input node/edge ids.
      local: partial feature of local node/edge ids.
      remotes: partial feature list of remote node/edge ids.
    """
    # Rows not covered by `local` or `remotes` stay zero.
    feat = torch.zeros(ids.shape[0], local[0].shape[1],
                       dtype=local[0].dtype, device=self.device)
    feat[local[1].to(self.device)] = local[0].to(self.device)
    for remote in remotes:
      feat[remote[1].to(self.device)] = remote[0].to(self.device)
    return feat