graphlearn_torch/python/typing.py (46 lines of code) (raw):
# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Dict, List, NamedTuple, Optional, Tuple, Union
import torch
import numpy as np
from enum import Enum
# Types for basic graph entity #################################################

# A node type is identified by a single string.
NodeType = str

# An edge type is identified by a triplet of strings:
# (source node type, edge name, destination node type).
EdgeType = Tuple[NodeType, str, NodeType]

# Separator placed between the three parts of an edge type when it is
# flattened into one string.
EDGE_TYPE_STR_SPLIT = '__'
def as_str(type: Union[NodeType, EdgeType]) -> str:
  r""" Convert a node type or an edge type into a single string.

  A node type is already a string and is returned as-is. An edge type given
  as a 3-element tuple or list is joined with ``EDGE_TYPE_STR_SPLIT``, e.g.
  ``('user', 'click', 'item')`` becomes ``'user__click__item'``. Any other
  input produces an empty string.
  """
  if isinstance(type, NodeType):
    return type
  is_triplet = isinstance(type, (list, tuple)) and len(type) == 3
  return EDGE_TYPE_STR_SPLIT.join(type) if is_triplet else ''
def reverse_edge_type(etype: EdgeType) -> EdgeType:
  r""" Return the edge type that points in the opposite direction of `etype`.

  The source and destination node types are always swapped. When they
  differ, the edge name is toggled: a leading ``'rev_'`` is removed if
  present, otherwise ``'rev_'`` is prepended, so applying this function twice
  restores the original edge type. When source and destination are the same
  node type, the edge name is kept unchanged.

  Args:
    etype: A ``(src, edge, dst)`` triplet of strings.

  Returns:
    The reversed ``(dst, edge', src)`` triplet.
  """
  rev_prefix = 'rev_'
  src, edge, dst = etype
  if src != dst:
    # Use an explicit prefix test: splitting on '_' would also match a bare
    # 'rev' relation name and then fail with IndexError when taking the tail.
    if edge.startswith(rev_prefix):
      edge = edge[len(rev_prefix):]
    else:
      edge = rev_prefix + edge
  return (dst, edge, src)
# Tensor-like data: either a torch tensor or a numpy array.
TensorDataType = Union[torch.Tensor, np.ndarray]

# Node labels / node indices: a single tensor-like value, or a dict mapping
# node types to tensor-like values (presumably the dict form is for
# heterogeneous graphs — confirm against callers).
NodeLabel = Union[
  TensorDataType,
  Dict[NodeType, TensorDataType],
]
NodeIndex = Union[
  TensorDataType,
  Dict[NodeType, TensorDataType],
]
class Split(Enum):
  r""" Names of the dataset splits; each member's value equals its name. """
  train = 'train'
  valid = 'valid'
  test = 'test'
# Types for partition data #####################################################

class GraphPartitionData(NamedTuple):
  r""" Data and indexing info of a graph partition.

  Fields:
    edge_index: Pair of tensors ``(rows, cols)`` holding the edges.
    eids: Edge ids corresponding to ``edge_index``.
    weights: Edge weights corresponding to ``edge_index``; defaults to
      ``None`` when not provided.
  """
  edge_index: Tuple[torch.Tensor, torch.Tensor]  # (rows, cols)
  eids: torch.Tensor  # edge ids matching `edge_index`
  weights: Optional[torch.Tensor] = None  # weights matching `edge_index`
class FeaturePartitionData(NamedTuple):
  r""" Data and indexing info of a node/edge feature partition.

  All fields are required positionally but may be ``None``.

  Fields:
    feats: Node/edge feature tensor.
    ids: Node/edge ids corresponding to ``feats``.
    cache_feats: Feature cache tensor.
    cache_ids: Cached node/edge ids corresponding to ``cache_feats``.
  """
  feats: Optional[torch.Tensor]
  ids: Optional[torch.Tensor]
  cache_feats: Optional[torch.Tensor]
  cache_ids: Optional[torch.Tensor]
# Heterogeneous variants: graph partitions keyed by edge type, feature
# partitions keyed by node type or edge type.
HeteroGraphPartitionData = Dict[EdgeType, GraphPartitionData]
HeteroFeaturePartitionData = Dict[
  Union[NodeType, EdgeType],
  FeaturePartitionData,
]
# Types for neighbor sampling ##################################################

# Seeds: a tensor or a string. NOTE(review): what the string form denotes is
# not visible in this file — confirm against callers.
Seeds = Union[torch.Tensor, str]

# Input nodes: bare seeds, a node type alone, or a node type paired with
# seeds or with a list of seeds.
InputNodes = Union[
  Seeds,
  NodeType,
  Tuple[NodeType, Seeds],
  Tuple[NodeType, List[Seeds]],
]

# Edge index: a single tensor or a pair of tensors.
EdgeIndexTensor = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

# Input edges: a bare edge index, an edge type alone, or an edge type paired
# with an edge index.
InputEdges = Union[
  EdgeIndexTensor,
  EdgeType,
  Tuple[EdgeType, EdgeIndexTensor],
]

# Neighbor counts: one list shared by all edges, or a list per edge type
# (presumably one entry per sampling hop — confirm against the sampler).
NumNeighbors = Union[List[int], Dict[EdgeType, List[int]]]