graphlearn_torch/python/data/vineyard_utils.py (94 lines of code) (raw):

# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# NOTE(review): the ImportError is swallowed on purpose here (presumably so the
# package can be imported when the vineyard extension is not built -- confirm).
# If the `pywrap` import fails, `pywrap` stays undefined and every wrapper
# below raises NameError when it is called.
try:
  import torch
  from typing import Dict
  from collections.abc import Sequence

  from .. import py_graphlearn_torch_vineyard as pywrap

except ImportError:
  pass

from ..partition import PartitionBook


def vineyard_to_csr(sock, fid, v_label_name, e_label_name, edge_dir, haseid=0):
  ''' Wrap to_csr function to read graph from vineyard.

  All arguments are forwarded unchanged, in the same order, to
  `pywrap.vineyard_to_csr`.

  Returns:
    (indptr, indices, (Optional)edge_id)
  '''
  return pywrap.vineyard_to_csr(
    sock, fid, v_label_name, e_label_name, edge_dir, haseid
  )


def load_vertex_feature_from_vineyard(sock, fid, vcols, v_label_name):
  ''' Wrap load_vertex_feature_from_vineyard function to read vertex feature
  from vineyard.

  Note that the underlying binding is called with the label name *before* the
  column list, i.e. `(sock, fid, v_label_name, vcols)`, which differs from the
  parameter order of this wrapper.

  Returns:
    vertex_feature(torch.Tensor)
  '''
  return pywrap.load_vertex_feature_from_vineyard(
    sock, fid, v_label_name, vcols
  )


def load_edge_feature_from_vineyard(sock, fid, ecols, e_label_name):
  ''' Wrap load_edge_feature_from_vineyard function to read edge feature
  from vineyard.

  Note that the underlying binding is called with the label name *before* the
  column list, i.e. `(sock, fid, e_label_name, ecols)`, which differs from the
  parameter order of this wrapper.

  Returns:
    edge_feature(torch.Tensor)
  '''
  return pywrap.load_edge_feature_from_vineyard(
    sock, fid, e_label_name, ecols
  )


def get_fid_from_gid(gid):
  ''' Wrap get_fid_from_gid function to get fid from gid.
  '''
  return pywrap.get_fid_from_gid(gid)


def get_frag_vertex_offset(sock, fid, v_label_name):
  ''' Wrap GetFragVertexOffset function to get vertex offset of a fragment.
  '''
  return pywrap.get_frag_vertex_offset(sock, fid, v_label_name)


def get_frag_vertex_num(sock, fid, v_label_name):
  ''' Wrap GetFragVertexNum function to get vertex number of a fragment.
  '''
  return pywrap.get_frag_vertex_num(sock, fid, v_label_name)


class VineyardPartitionBook(PartitionBook):
  ''' Partition book that resolves global vertex ids (gids) to fragment ids
  (fids) through a vineyard fragment handle, optionally remapping the fids to
  partition ids (pids).

  Args:
    sock: Passed through to the vineyard bindings.
    obj_id: Passed through to the vineyard bindings; used both for the vertex
      offset query and for creating the fragment handle.
    v_label_name: Vertex label name used for the vertex offset query.
    fid2pid: Optional mapping from fid to pid. When ``None``, the fids
      themselves are returned by ``__getitem__``.
  '''

  def __init__(self, sock, obj_id, v_label_name, fid2pid: Dict = None):
    self._sock = sock
    self._obj_id = obj_id
    self._v_label_name = v_label_name
    # The fragment handle is created lazily on the first `gid2fid` call.
    self._frag = None
    self._offset = get_frag_vertex_offset(sock, obj_id, v_label_name)
    # TODO: optimise this query process if too slow
    self._fid2pid = fid2pid

  def __getitem__(self, gids) -> torch.Tensor:
    ''' Return, as an int32 tensor, the pid of every gid in `gids` (or the
    fid when no `fid2pid` mapping was supplied).
    '''
    fids = self.gid2fid(gids)
    if self._fid2pid is None:
      return fids.to(torch.int32)
    # Iterating a tensor yields 0-d tensors, which hash by identity and thus
    # never match the plain-int keys of a dict; normalise each fid to a
    # Python int before the lookup.
    pids = torch.tensor([self._fid2pid[int(fid)] for fid in fids])
    return pids.to(torch.int32)

  @property
  def device(self):
    ''' Always the CPU device. '''
    return torch.device('cpu')

  @property
  def offset(self):
    ''' Vertex offset of the fragment, as queried at construction time. '''
    return self._offset

  def gid2fid(self, gids):
    ''' Parse gid to get fid.

    `gids` must provide `tolist()`; the resulting list is handed to the
    fragment handle, which is created on first use and then cached.
    '''
    if self._frag is None:
      self._frag = pywrap.VineyardFragHandle(self._sock, self._obj_id)
    fids = self._frag.get_fid_from_gid(gids.tolist())
    return fids


class VineyardGid2Lid(Sequence):
  ''' Sequence-like view translating gids of one fragment into local ids by
  subtracting the fragment's vertex offset.
  '''

  def __init__(self, sock, fid, v_label_name):
    self._offset = get_frag_vertex_offset(sock, fid, v_label_name)
    self._vnum = get_frag_vertex_num(sock, fid, v_label_name)

  def __getitem__(self, gids):
    # No bounds checking is performed; the offset is simply subtracted.
    return gids - self._offset

  def __len__(self):
    return self._vnum


def v6d_id_select(srcs, p_mask, node_pb: PartitionBook):
  ''' Select the inner vertices in `srcs` that belong to a specific partition,
  and return their local offsets in the partition.

  Args:
    srcs: Tensor of gids.
    p_mask: Boolean mask passed to `torch.masked_select` to pick the gids
      from `srcs`.
    node_pb: Partition book whose `offset` is subtracted from the selected
      gids.
  '''
  gids = torch.masked_select(srcs, p_mask)
  offsets = gids - node_pb.offset
  return offsets


def v6d_id_filter(node_pb: VineyardPartitionBook, partition_idx):
  ''' Select the inner vertices that belong to a specific partition.

  NOTE(review): `partition_idx` is currently unused; the result is the inner
  vertex set of the fragment that `node_pb` is bound to, obtained through a
  freshly created fragment handle.
  '''
  frag = pywrap.VineyardFragHandle(node_pb._sock, node_pb._obj_id)
  inner_vertices = frag.get_inner_vertices(node_pb._v_label_name)
  return inner_vertices