graphlearn_torch/python/sampler/base.py (240 lines of code) (raw):

# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from abc import ABC, abstractmethod
import math
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union, Literal

import torch

from ..typing import NodeType, EdgeType, NumNeighbors, Split
from ..utils import CastMixin


class EdgeIndex(NamedTuple):
  r""" PyG's :class:`~torch_geometric.loader.EdgeIndex` used in old data loader
  :class:`~torch_geometric.loader.NeighborSampler`:
  https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/loader/neighbor_sampler.py
  """
  edge_index: torch.Tensor
  e_id: Optional[torch.Tensor]
  size: Tuple[int, int]

  def to(self, *args, **kwargs):
    r""" Return a new :class:`EdgeIndex` whose tensors are converted with
    ``Tensor.to(*args, **kwargs)``. The original tuple is left unchanged.
    """
    edge_index = self.edge_index.to(*args, **kwargs)
    e_id = self.e_id.to(*args, **kwargs) if self.e_id is not None else None
    return EdgeIndex(edge_index, e_id, self.size)


@dataclass
class NodeSamplerInput(CastMixin):
  r""" The sampling input of
  :meth:`~graphlearn_torch.sampler.BaseSampler.sample_from_nodes`.

  This class corresponds to :class:`~torch_geometric.sampler.NodeSamplerInput`:
  https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/sampler/base.py

  Args:
    node (torch.Tensor): The indices of seed nodes to start sampling from.
    input_type (str, optional): The input node type (in case of sampling in
      a heterogeneous graph). (default: :obj:`None`).
  """
  node: torch.Tensor
  input_type: Optional[NodeType] = None

  def __getitem__(self, index: Union[torch.Tensor, Any]) -> 'NodeSamplerInput':
    r""" Select a subset of seed nodes, keeping the same :obj:`input_type`.

    Non-tensor indices are converted to a ``torch.long`` tensor and moved to
    the device of :obj:`node` before indexing.
    """
    if not isinstance(index, torch.Tensor):
      index = torch.tensor(index, dtype=torch.long)
    index = index.to(self.node.device)
    return NodeSamplerInput(self.node[index], self.input_type)

  def __len__(self):
    return self.node.numel()

  def share_memory(self):
    r""" Move :obj:`node` to shared memory in place and return ``self``. """
    self.node.share_memory_()
    return self

  def to(self, device: torch.device):
    r""" Move :obj:`node` to :obj:`device` and return ``self``. """
    # ``Tensor.to`` is not in-place; the result must be stored back,
    # otherwise this method silently does nothing.
    self.node = self.node.to(device)
    return self


class NegativeSamplingMode(Enum):
  # 'binary': Randomly sample negative edges in the graph.
  binary = 'binary'
  # 'triplet': Randomly sample negative destination nodes for each positive
  # source node.
  triplet = 'triplet'


@dataclass
class NegativeSampling(CastMixin):
  r"""The negative sampling configuration of a
  :class:`~torch_geometric.sampler.BaseSampler` when calling
  :meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges`.

  Args:
    mode (str or NegativeSamplingMode): The negative sampling mode
      (:obj:`"binary"` or :obj:`"triplet"`).
      If set to :obj:`"binary"`, will randomly sample negative links
      from the graph.
      If set to :obj:`"triplet"`, will randomly sample negative
      destination nodes for each positive source node.
    amount (int or float, optional): The ratio of sampled negative edges to
      the number of positive edges. Must be positive, and must be an
      integral value in :obj:`"triplet"` mode. (default: :obj:`1`)
    weight (torch.Tensor, optional): A node-level vector determining the
      sampling of nodes. Does not necessarily need to sum up to one.
      If not given, negative nodes will be sampled uniformly.
      (default: :obj:`None`)

  Raises:
    ValueError: If :obj:`mode` is not a valid :class:`NegativeSamplingMode`,
      if :obj:`amount` is not positive, or if :obj:`amount` is not integral
      in :obj:`"triplet"` mode.
  """
  mode: NegativeSamplingMode
  amount: Union[int, float] = 1
  weight: Optional[torch.Tensor] = None

  # NOTE: ``@dataclass`` does not overwrite an ``__init__`` defined in the
  # class body, so this validating constructor is the one that is used.
  def __init__(
    self,
    mode: Union[NegativeSamplingMode, str],
    amount: Union[int, float] = 1,
    weight: Optional[torch.Tensor] = None,
  ):
    self.mode = NegativeSamplingMode(mode)
    self.amount = amount
    self.weight = weight

    if self.amount <= 0:
      raise ValueError(f"The attribute 'amount' needs to be positive "
                       f"for '{self.__class__.__name__}' "
                       f"(got {self.amount})")

    if self.is_triplet():
      if self.amount != math.ceil(self.amount):
        raise ValueError(f"The attribute 'amount' needs to be an "
                         f"integer for '{self.__class__.__name__}' "
                         f"with 'triplet' negative sampling "
                         f"(got {self.amount}).")
      # Normalize an integral float (e.g. 2.0) to an int.
      self.amount = math.ceil(self.amount)

  def is_binary(self) -> bool:
    return self.mode == NegativeSamplingMode.binary

  def is_triplet(self) -> bool:
    return self.mode == NegativeSamplingMode.triplet

  def share_memory(self):
    r""" Move :obj:`weight` (if any) to shared memory and return ``self``. """
    if self.weight is not None:
      self.weight.share_memory_()
    return self

  def to(self, device: torch.device):
    r""" Move :obj:`weight` (if any) to :obj:`device` and return ``self``. """
    if self.weight is not None:
      # ``Tensor.to`` is not in-place; store the moved tensor.
      self.weight = self.weight.to(device)
    return self


@dataclass
class EdgeSamplerInput(CastMixin):
  r""" The sampling input of
  :meth:`~graphlearn_torch.sampler.BaseSampler.sample_from_edges`.

  This class corresponds to :class:`~torch_geometric.sampler.EdgeSamplerInput`:
  https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/sampler/base.py

  Args:
    row (torch.Tensor): The source node indices of seed links to start
      sampling from.
    col (torch.Tensor): The destination node indices of seed links to start
      sampling from.
    label (torch.Tensor, optional): The label for the seed links.
      (default: :obj:`None`).
    input_type (Tuple[str, str, str], optional): The input edge type (in
      case of sampling in a heterogeneous graph). (default: :obj:`None`).
    neg_sampling (NegativeSampling, optional): The negative sampling
      configuration shared by all seed links. (default: :obj:`None`).
  """
  row: torch.Tensor
  col: torch.Tensor
  label: Optional[torch.Tensor] = None
  input_type: Optional[EdgeType] = None
  neg_sampling: Optional[NegativeSampling] = None

  def __getitem__(self, index: Union[torch.Tensor, Any]) -> 'EdgeSamplerInput':
    r""" Select a subset of seed links. :obj:`input_type` and
    :obj:`neg_sampling` are carried over unchanged.
    """
    if not isinstance(index, torch.Tensor):
      index = torch.tensor(index, dtype=torch.long)
    index = index.to(self.row.device)
    return EdgeSamplerInput(
      self.row[index],
      self.col[index],
      self.label[index] if self.label is not None else None,
      self.input_type,
      self.neg_sampling
    )

  def __len__(self):
    return self.row.numel()

  def share_memory(self):
    r""" Move all tensors (including the negative sampling weight, if any)
    to shared memory and return ``self``.
    """
    self.row.share_memory_()
    self.col.share_memory_()
    if self.label is not None:
      self.label.share_memory_()
    # ``neg_sampling`` is optional independently of ``label``; guarding on
    # ``label`` here would crash when only ``label`` is set.
    if self.neg_sampling is not None:
      self.neg_sampling.share_memory()
    return self

  def to(self, device: torch.device):
    r""" Move all tensors (including the negative sampling weight, if any)
    to :obj:`device` and return ``self``.
    """
    # ``Tensor.to`` is not in-place; store the moved tensors.
    self.row = self.row.to(device)
    self.col = self.col.to(device)
    if self.label is not None:
      self.label = self.label.to(device)
    if self.neg_sampling is not None:
      self.neg_sampling.to(device)
    return self


@dataclass
class SamplerOutput(CastMixin):
  r""" The sampling output of a :class:`~graphlearn_torch.sampler.BaseSampler`
  on homogeneous graphs.

  This class corresponds to :class:`~torch_geometric.sampler.SamplerOutput`:
  https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/sampler/base.py

  Args:
    node (torch.Tensor): The sampled nodes in the original graph.
    row (torch.Tensor): The source node indices of the sampled subgraph.
      Indices must be re-indexed to :obj:`{ 0, ..., num_nodes - 1 }`
      corresponding to the nodes in the :obj:`node` tensor.
    col (torch.Tensor): The destination node indices of the sampled subgraph.
      Indices must be re-indexed to :obj:`{ 0, ..., num_nodes - 1 }`
      corresponding to the nodes in the :obj:`node` tensor.
    edge (torch.Tensor, optional): The sampled edges in the original graph.
      This tensor is used to obtain edge features from the original
      graph. If no edge attributes are present, it may be omitted.
    batch (torch.Tensor, optional): The vector to identify the seed node
      for each sampled node. Can be present in case of disjoint subgraph
      sampling per seed node. (default: :obj:`None`).
    num_sampled_nodes (List[int] or torch.Tensor, optional): The number of
      sampled nodes (presumably per hop, as in PyG; confirm against the
      concrete sampler). (default: :obj:`None`).
    num_sampled_edges (List[int] or torch.Tensor, optional): The number of
      sampled edges (presumably per hop, as in PyG; confirm against the
      concrete sampler). (default: :obj:`None`).
    device (torch.device, optional): The device that all data of this output
      resides in. (default: :obj:`None`).
    metadata (Any, optional): Additional metadata information.
      (default: :obj:`None`).
  """
  node: torch.Tensor
  row: torch.Tensor
  col: torch.Tensor
  edge: Optional[torch.Tensor] = None
  batch: Optional[torch.Tensor] = None
  num_sampled_nodes: Optional[Union[List[int], torch.Tensor]] = None
  num_sampled_edges: Optional[Union[List[int], torch.Tensor]] = None
  device: Optional[torch.device] = None
  metadata: Optional[Any] = None


@dataclass
class HeteroSamplerOutput(CastMixin):
  r""" The sampling output of a :class:`~graphlearn_torch.sampler.BaseSampler`
  on heterogeneous graphs.

  This class corresponds to
  :class:`~torch_geometric.sampler.HeteroSamplerOutput`:
  https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/sampler/base.py

  Args:
    node (Dict[str, torch.Tensor]): The sampled nodes in the original graph
      for each node type.
    row (Dict[Tuple[str, str, str], torch.Tensor]): The source node indices
      of the sampled subgraph for each edge type. Indices must be re-indexed
      to :obj:`{ 0, ..., num_nodes - 1 }` corresponding to the nodes in the
      :obj:`node` tensor of the source node type.
    col (Dict[Tuple[str, str, str], torch.Tensor]): The destination node
      indices of the sampled subgraph for each edge type. Indices must be
      re-indexed to :obj:`{ 0, ..., num_nodes - 1 }` corresponding to the nodes
      in the :obj:`node` tensor of the destination node type.
    edge (Dict[Tuple[str, str, str], torch.Tensor], optional): The sampled
      edges in the original graph for each edge type. This tensor is used to
      obtain edge features from the original graph. If no edge attributes are
      present, it may be omitted. (default: :obj:`None`).
    batch (Dict[str, torch.Tensor], optional): The vector to identify the
      seed node for each sampled node for each node type. Can be present
      in case of disjoint subgraph sampling per seed node.
      (default: :obj:`None`).
    num_sampled_nodes (Dict[str, List[int] or torch.Tensor], optional): The
      number of sampled nodes for each node type (presumably per hop, as in
      PyG; confirm against the concrete sampler). (default: :obj:`None`).
    num_sampled_edges (Dict[Tuple[str, str, str], List[int] or torch.Tensor],
      optional): The number of sampled edges for each edge type (presumably
      per hop, as in PyG; confirm against the concrete sampler).
      (default: :obj:`None`).
    edge_types (List[Tuple[str, str, str]], optional): The list of edge types
      of the sampled subgraph. (default: :obj:`None`).
    input_type (Union[NodeType, EdgeType], optional): The input type of seed
      nodes or edge_label_index. (default: :obj:`None`).
    device (torch.device, optional): The device that all data of this output
      resides in. (default: :obj:`None`).
    metadata (Any, optional): Additional metadata information.
      (default: :obj:`None`)
  """
  node: Dict[NodeType, torch.Tensor]
  row: Dict[EdgeType, torch.Tensor]
  col: Dict[EdgeType, torch.Tensor]
  edge: Optional[Dict[EdgeType, torch.Tensor]] = None
  batch: Optional[Dict[NodeType, torch.Tensor]] = None
  num_sampled_nodes: Optional[Dict[NodeType, Union[List[int], torch.Tensor]]] = None
  num_sampled_edges: Optional[Dict[EdgeType, Union[List[int], torch.Tensor]]] = None
  edge_types: Optional[List[EdgeType]] = None
  input_type: Optional[Union[NodeType, EdgeType]] = None
  device: Optional[torch.device] = None
  metadata: Optional[Any] = None

  def get_edge_index(self):
    r""" Build a ``{edge_type: [2, num_edges] tensor}`` mapping by stacking
    :obj:`row` and :obj:`col` per edge type.

    Every type listed in :obj:`edge_types` that has no sampled edges gets an
    empty ``(2, 0)`` long tensor on :obj:`device`.
    """
    edge_index = {k: torch.stack([v, self.col[k]]) for k, v in self.row.items()}
    if self.edge_types is not None:
      for etype in self.edge_types:
        if edge_index.get(etype, None) is None:
          # Allocate directly on the target device instead of creating the
          # tensor on the default device and then copying it.
          edge_index[etype] = torch.empty(
            (2, 0), dtype=torch.long, device=self.device)
    return edge_index


@dataclass
class NeighborOutput(CastMixin):
  r""" The output of sampled neighbor results for a single hop sampling.

  Args:
    nbr (torch.Tensor): A 1D tensor of all sampled neighborhood node ids.
    nbr_num (torch.Tensor): A 1D tensor that identifies the number of
      neighborhood nodes for each source node. Must be the same length as
      the source nodes of this sampling hop.
    edge (torch.Tensor, optional): The edge ids corresponding to the sampled
      edges (from source node to the sampled neighborhood node). Should be
      the same length as :obj:`nbr` if provided.
  """
  nbr: torch.Tensor
  nbr_num: torch.Tensor
  edge: Optional[torch.Tensor]

  def to(self, device: torch.device):
    r""" Return a new :class:`NeighborOutput` with all tensors moved to
    :obj:`device`.
    """
    return NeighborOutput(
      nbr=self.nbr.to(device),
      nbr_num=self.nbr_num.to(device),
      edge=(self.edge.to(device) if self.edge is not None else None)
    )


class SamplingType(Enum):
  r""" Enum class for sampling types.
  """
  NODE = 0
  LINK = 1
  SUBGRAPH = 2
  RANDOM_WALK = 3


@dataclass
class SamplingConfig:
  r""" Configuration info for sampling.
  """
  sampling_type: SamplingType
  num_neighbors: Optional[NumNeighbors]
  batch_size: int
  shuffle: bool
  drop_last: bool
  with_edge: bool
  collect_features: bool
  with_neg: bool
  with_weight: bool
  edge_dir: Literal['in', 'out']
  seed: int


class BaseSampler(ABC):
  r""" A base class that initializes a graph sampler and provides
  :meth:`sample_from_nodes` and :meth:`sample_from_edges` routines.

  This class corresponds to :class:`~torch_geometric.sampler.BaseSampler`:
  https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/sampler/base.py
  """
  @abstractmethod
  def sample_from_nodes(
    self,
    inputs: NodeSamplerInput,
    **kwargs
  ) -> Union[HeteroSamplerOutput, SamplerOutput]:
    r""" Performs sampling from the nodes specified in :obj:`inputs`,
    returning a sampled subgraph (egograph) in the specified output format.

    Args:
      inputs (NodeSamplerInput): The input data with node indices to start
        sampling from.
    """

  @abstractmethod
  def sample_from_edges(
    self,
    inputs: EdgeSamplerInput,
    **kwargs,
  ) -> Union[HeteroSamplerOutput, SamplerOutput]:
    r""" Performs sampling from the edges specified in :obj:`inputs`,
    returning a sampled subgraph (egograph) in the specified output format.

    Args:
      inputs (EdgeSamplerInput): The input data for sampling from edges
        including the (1) source node indices, the (2) destination node
        indices, the (3) optional edge labels and the (4) input edge type.
    """

  @abstractmethod
  def subgraph(
    self,
    inputs: NodeSamplerInput,
  ) -> SamplerOutput:
    r""" Induce an enclosing subgraph based on inputs and their neighbors
    (if num_neighbors is not None).

    Args:
      inputs (NodeSamplerInput): The input data with node indices to induce
        subgraph from.

    Returns:
      The sampled unique nodes, relabeled rows and cols, original edge_ids,
      and a mapping from indices in `inputs` to new indices in output nodes,
      i.e. nodes[mapping] = inputs.
    """


class RemoteSamplerInput(ABC):
  r""" A base class that provides the `to_local_sampler_input` method
  for the server to obtain the sampler input.
  """
  @abstractmethod
  def to_local_sampler_input(
    self,
    dataset,
    **kwargs
  ) -> Union[NodeSamplerInput, EdgeSamplerInput]:
    r""" Abstract method to convert the sampler input to local format.
    """


class RemoteNodePathSamplerInput(RemoteSamplerInput):
  r""" RemoteNodePathSamplerInput passes the node path to the server,
  where the server can load node seeds from it.
  """
  def __init__(self, node_path: str, input_type: str) -> None:
    self.node_path = node_path
    self.input_type = input_type

  def to_local_sampler_input(
    self,
    dataset,
    **kwargs,
  ) -> NodeSamplerInput:
    r""" Load the seed node tensor from :obj:`node_path` and wrap it in a
    :class:`NodeSamplerInput`. :obj:`dataset` is unused.
    """
    # NOTE(review): ``torch.load`` unpickles the file, and the path comes from
    # the client. Only load paths from trusted sources, or consider
    # ``weights_only=True`` where the installed torch version supports it.
    node = torch.load(self.node_path)
    return NodeSamplerInput(node=node, input_type=self.input_type)


class RemoteNodeSplitSamplerInput(RemoteSamplerInput):
  r""" RemoteNodeSplitSamplerInput passes the split category to the server
  and the server loads seeds from the dataset.
  """
  def __init__(self, split: Split, input_type: str) -> None:
    self.split = split
    self.input_type = input_type

  def to_local_sampler_input(
    self,
    dataset,
    **kwargs,
  ) -> NodeSamplerInput:
    r""" Pick the seed indices of :obj:`split` from :obj:`dataset`
    (``train_idx`` / ``val_idx`` / ``test_idx``). If those are stored per node
    type in a dict, select the entry for :obj:`input_type`.

    Raises:
      ValueError: If :obj:`split` is not one of train / valid / test.
    """
    if self.split == Split.train:
      idx = dataset.train_idx
    elif self.split == Split.valid:
      idx = dataset.val_idx
    elif self.split == Split.test:
      idx = dataset.test_idx
    else:
      # Previously this fell through and raised an opaque UnboundLocalError.
      raise ValueError(f"Unsupported split '{self.split}' for "
                       f"'{self.__class__.__name__}'")

    if isinstance(idx, dict):
      idx = idx[self.input_type]
    return NodeSamplerInput(node=idx, input_type=self.input_type)