graphlearn_torch/python/sampler/base.py (240 lines of code) (raw):
# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from abc import ABC, abstractmethod
import math
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union, Literal
import torch
from ..typing import NodeType, EdgeType, NumNeighbors, Split
from ..utils import CastMixin
class EdgeIndex(NamedTuple):
  r""" Counterpart of PyG's :class:`~torch_geometric.loader.EdgeIndex`, the
  record type used by the old data loader
  :class:`~torch_geometric.loader.NeighborSampler`:
  https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/loader/neighbor_sampler.py

  Fields:
    edge_index (torch.Tensor): The edge index tensor.
    e_id (torch.Tensor, optional): An optional tensor stored next to
      :obj:`edge_index`; may be :obj:`None`.
    size (Tuple[int, int]): A pair of integers carried along unchanged.
  """
  edge_index: torch.Tensor
  e_id: Optional[torch.Tensor]
  size: Tuple[int, int]

  def to(self, *args, **kwargs):
    r""" Build a new :class:`EdgeIndex` whose tensors are the result of
    ``Tensor.to(*args, **kwargs)``. :obj:`size` is copied as-is and a missing
    :obj:`e_id` stays :obj:`None`.
    """
    converted_index = self.edge_index.to(*args, **kwargs)
    converted_ids = None
    if self.e_id is not None:
      converted_ids = self.e_id.to(*args, **kwargs)
    return EdgeIndex(converted_index, converted_ids, self.size)
@dataclass
class NodeSamplerInput(CastMixin):
  r""" The sampling input of
  :meth:`~graphlearn_torch.sampler.BaseSampler.sample_from_nodes`.

  This class corresponds to :class:`~torch_geometric.sampler.NodeSamplerInput`:
  https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/sampler/base.py

  Args:
    node (torch.Tensor): The indices of seed nodes to start sampling from.
    input_type (str, optional): The input node type (in case of sampling in
      a heterogeneous graph). (default: :obj:`None`).
  """
  node: torch.Tensor
  input_type: Optional[NodeType] = None

  def __getitem__(self, index: Union[torch.Tensor, Any]) -> 'NodeSamplerInput':
    r""" Return a new input holding only the seeds selected by :obj:`index`.

    A non-tensor :obj:`index` is first converted to a ``torch.long`` tensor;
    the index is then moved to the device of :obj:`node` before indexing.
    """
    if not isinstance(index, torch.Tensor):
      index = torch.tensor(index, dtype=torch.long)
    index = index.to(self.node.device)
    return NodeSamplerInput(self.node[index], self.input_type)

  def __len__(self):
    r""" Number of elements in :obj:`node`. """
    return self.node.numel()

  def share_memory(self):
    r""" Move :obj:`node` into shared memory in place and return ``self``. """
    self.node.share_memory_()
    return self

  def to(self, device: torch.device):
    r""" Move :obj:`node` to :obj:`device` and return ``self``. """
    # ``Tensor.to`` is not an in-place op: it returns the converted tensor.
    # The result has to be stored back, otherwise this method does nothing.
    self.node = self.node.to(device)
    return self
class NegativeSamplingMode(Enum):
  r""" The available negative sampling modes.

  * ``binary``: Randomly sample negative edges in the graph.
  * ``triplet``: Randomly sample negative destination nodes for each positive
    source node.
  """
  binary = 'binary'
  triplet = 'triplet'
@dataclass
class NegativeSampling(CastMixin):
  r"""The negative sampling configuration of a
  :class:`~torch_geometric.sampler.BaseSampler` when calling
  :meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges`.

  Args:
    mode (str): The negative sampling mode
      (:obj:`"binary"` or :obj:`"triplet"`).
      If set to :obj:`"binary"`, will randomly sample negative links
      from the graph.
      If set to :obj:`"triplet"`, will randomly sample negative
      destination nodes for each positive source node.
    amount (int or float, optional): The ratio of sampled negative edges to
      the number of positive edges. (default: :obj:`1`)
    weight (torch.Tensor, optional): A node-level vector determining the
      sampling of nodes. Does not necessarily need to sum up to one.
      If not given, negative nodes will be sampled uniformly.
      (default: :obj:`None`)

  Raises:
    ValueError: If :obj:`mode` is not a valid :class:`NegativeSamplingMode`,
      if :obj:`amount` is not positive, or if :obj:`amount` is not
      integer-valued in :obj:`"triplet"` mode.
  """
  mode: NegativeSamplingMode
  amount: Union[int, float] = 1
  weight: Optional[torch.Tensor] = None

  # An ``__init__`` defined in the class body is kept by ``@dataclass``, so
  # this constructor (with its validation) is the one that actually runs.
  def __init__(
    self,
    mode: Union[NegativeSamplingMode, str],
    amount: Union[int, float] = 1,
    weight: Optional[torch.Tensor] = None,
  ):
    # Accepts either the enum member or its string value.
    self.mode = NegativeSamplingMode(mode)
    self.amount = amount
    self.weight = weight

    if self.amount <= 0:
      raise ValueError(f"The attribute 'amount' needs to be positive "
                       f"for '{self.__class__.__name__}' "
                       f"(got {self.amount})")

    if self.is_triplet():
      if self.amount != math.ceil(self.amount):
        raise ValueError(f"The attribute 'amount' needs to be an "
                         f"integer for '{self.__class__.__name__}' "
                         f"with 'triplet' negative sampling "
                         f"(got {self.amount}).")
      # Normalize integer-valued floats (e.g. ``2.0``) to ``int``.
      self.amount = math.ceil(self.amount)

  def is_binary(self) -> bool:
    r""" Whether the mode is :obj:`NegativeSamplingMode.binary`. """
    return self.mode == NegativeSamplingMode.binary

  def is_triplet(self) -> bool:
    r""" Whether the mode is :obj:`NegativeSamplingMode.triplet`. """
    return self.mode == NegativeSamplingMode.triplet

  def share_memory(self):
    r""" Move :obj:`weight` (if any) into shared memory in place and return
    ``self``.
    """
    if self.weight is not None:
      self.weight.share_memory_()
    return self

  def to(self, device: torch.device):
    r""" Move :obj:`weight` (if any) to :obj:`device` and return ``self``. """
    if self.weight is not None:
      # ``Tensor.to`` returns the converted tensor rather than modifying it
      # in place, so the result must be assigned back.
      self.weight = self.weight.to(device)
    return self
@dataclass
class EdgeSamplerInput(CastMixin):
  r""" The sampling input of
  :meth:`~graphlearn_torch.sampler.BaseSampler.sample_from_edges`.

  This class corresponds to :class:`~torch_geometric.sampler.EdgeSamplerInput`:
  https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/sampler/base.py

  Args:
    row (torch.Tensor): The source node indices of seed links to start
      sampling from.
    col (torch.Tensor): The destination node indices of seed links to start
      sampling from.
    label (torch.Tensor, optional): The label for the seed links.
      (default: :obj:`None`).
    input_type (Tuple[str, str, str], optional): The input edge type (in
      case of sampling in a heterogeneous graph). (default: :obj:`None`).
    neg_sampling (NegativeSampling, optional): The negative sampling
      configuration attached to this input. (default: :obj:`None`).
  """
  row: torch.Tensor
  col: torch.Tensor
  label: Optional[torch.Tensor] = None
  input_type: Optional[EdgeType] = None
  neg_sampling: Optional[NegativeSampling] = None

  def __getitem__(self, index: Union[torch.Tensor, Any]) -> 'EdgeSamplerInput':
    r""" Return a new input holding only the seed links selected by
    :obj:`index`.

    A non-tensor :obj:`index` is first converted to a ``torch.long`` tensor;
    the index is then moved to the device of :obj:`row`. :obj:`input_type`
    and :obj:`neg_sampling` are passed through unchanged.
    """
    if not isinstance(index, torch.Tensor):
      index = torch.tensor(index, dtype=torch.long)
    index = index.to(self.row.device)
    return EdgeSamplerInput(
      self.row[index],
      self.col[index],
      self.label[index] if self.label is not None else None,
      self.input_type,
      self.neg_sampling
    )

  def __len__(self):
    r""" Number of elements in :obj:`row`. """
    return self.row.numel()

  def share_memory(self):
    r""" Move all held tensors into shared memory in place and return
    ``self``.
    """
    self.row.share_memory_()
    self.col.share_memory_()
    if self.label is not None:
      self.label.share_memory_()
    # Guard on ``neg_sampling`` itself: it is independent of ``label``, and
    # either one may be absent while the other is present.
    if self.neg_sampling is not None:
      self.neg_sampling.share_memory()
    return self

  def to(self, device: torch.device):
    r""" Move all held tensors to :obj:`device` and return ``self``. """
    # ``Tensor.to`` returns the converted tensor instead of modifying it in
    # place, so every result is stored back on the instance.
    self.row = self.row.to(device)
    self.col = self.col.to(device)
    if self.label is not None:
      self.label = self.label.to(device)
    # Guard on ``neg_sampling`` itself rather than on ``label``.
    if self.neg_sampling is not None:
      self.neg_sampling = self.neg_sampling.to(device)
    return self
@dataclass
class SamplerOutput(CastMixin):
  r""" Result record produced by a
  :class:`~graphlearn_torch.sampler.BaseSampler` when sampling on a
  homogeneous graph.

  Mirrors :class:`~torch_geometric.sampler.SamplerOutput`:
  https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/sampler/base.py

  Args:
    node (torch.Tensor): The sampled nodes in the original graph.
    row (torch.Tensor): Source node indices of the sampled subgraph,
      re-indexed to :obj:`{ 0, ..., num_nodes - 1 }` so that they refer to
      positions in the :obj:`node` tensor.
    col (torch.Tensor): Destination node indices of the sampled subgraph,
      re-indexed to :obj:`{ 0, ..., num_nodes - 1 }` so that they refer to
      positions in the :obj:`node` tensor.
    edge (torch.Tensor, optional): The sampled edges in the original graph,
      used to obtain edge features from the original graph. May be omitted
      when no edge attributes are present. (default: :obj:`None`).
    batch (torch.Tensor, optional): The vector identifying the seed node of
      each sampled node. Can be present in case of disjoint subgraph
      sampling per seed node. (default: :obj:`None`).
    num_sampled_nodes (List[int] or torch.Tensor, optional): Not described
      in this module; presumably the number of sampled nodes per hop, as in
      the PyG counterpart -- TODO confirm. (default: :obj:`None`).
    num_sampled_edges (List[int] or torch.Tensor, optional): Not described
      in this module; presumably the number of sampled edges per hop, as in
      the PyG counterpart -- TODO confirm. (default: :obj:`None`).
    device (torch.device, optional): The device that all data of this output
      resides in. (default: :obj:`None`).
    metadata (Any, optional): Additional metadata information.
      (default: :obj:`None`).
  """
  # Required fields.
  node: torch.Tensor
  row: torch.Tensor
  col: torch.Tensor
  # Optional fields, all defaulting to None.
  edge: Optional[torch.Tensor] = None
  batch: Optional[torch.Tensor] = None
  num_sampled_nodes: Optional[Union[List[int], torch.Tensor]] = None
  num_sampled_edges: Optional[Union[List[int], torch.Tensor]] = None
  device: Optional[torch.device] = None
  metadata: Optional[Any] = None
@dataclass
class HeteroSamplerOutput(CastMixin):
  r""" Result record produced by a
  :class:`~graphlearn_torch.sampler.BaseSampler` when sampling on a
  heterogeneous graph.

  Mirrors :class:`~torch_geometric.sampler.HeteroSamplerOutput`:
  https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/sampler/base.py

  Args:
    node (Dict[str, torch.Tensor]): The sampled nodes in the original graph
      for each node type.
    row (Dict[Tuple[str, str, str], torch.Tensor]): Source node indices of
      the sampled subgraph for each edge type, re-indexed to
      :obj:`{ 0, ..., num_nodes - 1 }` so that they refer to positions in
      the :obj:`node` tensor of the source node type.
    col (Dict[Tuple[str, str, str], torch.Tensor]): Destination node indices
      of the sampled subgraph for each edge type, re-indexed to
      :obj:`{ 0, ..., num_nodes - 1 }` so that they refer to positions in
      the :obj:`node` tensor of the destination node type.
    edge (Dict[Tuple[str, str, str], torch.Tensor], optional): The sampled
      edges in the original graph for each edge type, used to obtain edge
      features from the original graph. May be omitted when no edge
      attributes are present. (default: :obj:`None`).
    batch (Dict[str, torch.Tensor], optional): For each node type, the
      vector identifying the seed node of each sampled node. Can be present
      in case of disjoint subgraph sampling per seed node.
      (default: :obj:`None`).
    num_sampled_nodes (Dict[str, List[int] or torch.Tensor], optional): Not
      described in this module; presumably the number of sampled nodes per
      hop for each node type, as in the PyG counterpart -- TODO confirm.
      (default: :obj:`None`).
    num_sampled_edges (Dict[Tuple[str, str, str], List[int] or torch.Tensor], optional):
      Not described in this module; presumably the number of sampled edges
      per hop for each edge type, as in the PyG counterpart -- TODO confirm.
      (default: :obj:`None`).
    edge_types (List[Tuple[str, str, str]], optional): The list of edge
      types of the sampled subgraph. (default: :obj:`None`).
    input_type (Union[NodeType, EdgeType], optional): The input type of seed
      nodes or edge_label_index. (default: :obj:`None`).
    device (torch.device, optional): The device that all data of this output
      resides in. (default: :obj:`None`).
    metadata (Any, optional): Additional metadata information.
      (default: :obj:`None`).
  """
  # Required fields.
  node: Dict[NodeType, torch.Tensor]
  row: Dict[EdgeType, torch.Tensor]
  col: Dict[EdgeType, torch.Tensor]
  # Optional fields, all defaulting to None.
  edge: Optional[Dict[EdgeType, torch.Tensor]] = None
  batch: Optional[Dict[NodeType, torch.Tensor]] = None
  num_sampled_nodes: Optional[Dict[NodeType, Union[List[int], torch.Tensor]]] = None
  num_sampled_edges: Optional[Dict[EdgeType, Union[List[int], torch.Tensor]]] = None
  edge_types: Optional[List[EdgeType]] = None
  input_type: Optional[Union[NodeType, EdgeType]] = None
  device: Optional[torch.device] = None
  metadata: Optional[Any] = None

  def get_edge_index(self):
    r""" Build a dict mapping each edge type to a stacked ``[row, col]``
    tensor.

    Every key of :obj:`row` gets an entry. When :obj:`edge_types` is given,
    each listed edge type without an entry is filled with an empty
    ``(2, 0)`` long tensor moved to :obj:`device`.
    """
    stacked = {}
    for etype, src in self.row.items():
      stacked[etype] = torch.stack([src, self.col[etype]])

    if self.edge_types is None:
      return stacked

    for etype in self.edge_types:
      if stacked.get(etype, None) is None:
        placeholder = torch.empty((2, 0), dtype=torch.long)
        stacked[etype] = placeholder.to(self.device)
    return stacked
@dataclass
class NeighborOutput(CastMixin):
  r""" Sampled neighbor results of one single sampling hop.

  Args:
    nbr (torch.Tensor): A 1D tensor of all sampled neighborhood node ids.
    nbr_num (torch.Tensor): A 1D tensor giving the number of neighborhood
      nodes for each source node. Must have the same length as the source
      nodes of this sampling hop.
    edge (torch.Tensor, optional): The edge ids corresponding to the sampled
      edges (from source node to the sampled neighborhood node). Should have
      the same length as :obj:`nbr` if provided.
  """
  nbr: torch.Tensor
  nbr_num: torch.Tensor
  edge: Optional[torch.Tensor]

  def to(self, device: torch.device):
    r""" Return a new :class:`NeighborOutput` with every tensor moved to
    :obj:`device`; a missing :obj:`edge` stays :obj:`None`.
    """
    moved_nbr = self.nbr.to(device)
    moved_nbr_num = self.nbr_num.to(device)
    moved_edge = None
    if self.edge is not None:
      moved_edge = self.edge.to(device)
    return NeighborOutput(nbr=moved_nbr, nbr_num=moved_nbr_num, edge=moved_edge)
class SamplingType(Enum):
  r""" The kinds of sampling, as an enum.

  Members and their values: ``NODE = 0``, ``LINK = 1``, ``SUBGRAPH = 2``,
  ``RANDOM_WALK = 3``.
  """
  # Values are assigned in declaration order, starting at 0.
  NODE, LINK, SUBGRAPH, RANDOM_WALK = range(4)
@dataclass
class SamplingConfig:
  r""" Bundle of configuration values for sampling.

  All fields are required and are given in declaration order. This module
  only declares the fields; their exact meaning is defined by the code that
  consumes the config, so the grouping comments below are inferred from the
  field names -- TODO confirm against the consumers.
  """
  # Which kind of sampling to run, and the neighbor counts (may be None).
  sampling_type: SamplingType
  num_neighbors: Optional[NumNeighbors]
  # Batching options (presumably loader-style batching).
  batch_size: int
  shuffle: bool
  drop_last: bool
  # Boolean feature switches.
  with_edge: bool
  collect_features: bool
  with_neg: bool
  with_weight: bool
  # Edge direction, restricted to the literals 'in' and 'out'.
  edge_dir: Literal['in', 'out']
  # Integer seed (presumably for random number generation).
  seed: int
class BaseSampler(ABC):
  r""" Abstract interface of a graph sampler, exposing the
  :meth:`sample_from_nodes`, :meth:`sample_from_edges` and :meth:`subgraph`
  routines that concrete samplers must implement.

  Mirrors :class:`~torch_geometric.sampler.BaseSampler`:
  https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/sampler/base.py
  """

  @abstractmethod
  def sample_from_nodes(
    self,
    inputs: NodeSamplerInput,
    **kwargs
  ) -> Union[HeteroSamplerOutput, SamplerOutput]:
    r""" Sample starting from the nodes given in :obj:`inputs` and return
    the sampled subgraph (egograph) in the specified output format.

    Args:
      inputs (NodeSamplerInput): The input data with node indices to start
        sampling from.
    """

  @abstractmethod
  def sample_from_edges(
    self,
    inputs: EdgeSamplerInput,
    **kwargs,
  ) -> Union[HeteroSamplerOutput, SamplerOutput]:
    r""" Sample starting from the edges given in :obj:`inputs` and return
    the sampled subgraph (egograph) in the specified output format.

    Args:
      inputs (EdgeSamplerInput): The input data for sampling from edges,
        including (1) the source node indices, (2) the destination node
        indices, (3) the optional edge labels and (4) the input edge type.
    """

  @abstractmethod
  def subgraph(
    self,
    inputs: NodeSamplerInput,
  ) -> SamplerOutput:
    r""" Induce an enclosing subgraph from the input nodes and their
    neighbors (if num_neighbors is not None).

    Args:
      inputs (NodeSamplerInput): The input data with node indices to induce
        the subgraph from.

    Returns:
      The sampled unique nodes, relabeled rows and cols, original edge_ids,
      and a mapping from indices in `inputs` to new indices in output nodes,
      i.e. nodes[mapping] = inputs.
    """
class RemoteSamplerInput(ABC):
  r""" Abstract sampler input that is resolved on the server side: subclasses
  implement `to_local_sampler_input` so that the server can obtain the
  actual sampler input.
  """

  @abstractmethod
  def to_local_sampler_input(
    self,
    dataset,
    **kwargs
  ) -> Union[NodeSamplerInput, EdgeSamplerInput]:
    r""" Convert this remote sampler input into its local format.

    Args:
      dataset: The dataset object made available to the conversion.
      **kwargs: Extra keyword arguments for implementations.
    """
class RemoteNodePathSamplerInput(RemoteSamplerInput):
  r""" Remote sampler input that carries a node file path to the server; the
  server loads the node seeds from that path.
  """

  def __init__(self, node_path: str, input_type: str) -> None:
    r""" Store the path of the seed file and the input node type. """
    self.input_type = input_type
    self.node_path = node_path

  def to_local_sampler_input(
    self,
    dataset,
    **kwargs,
  ) -> NodeSamplerInput:
    r""" Load the seeds from :obj:`node_path` and wrap them in a
    :class:`NodeSamplerInput`. :obj:`dataset` is not used here.
    """
    # NOTE(review): ``torch.load`` unpickles the file, and the path is
    # supplied by the remote side -- make sure only trusted files can be
    # referenced here.
    seeds = torch.load(self.node_path)
    return NodeSamplerInput(node=seeds, input_type=self.input_type)
class RemoteNodeSplitSamplerInput(RemoteSamplerInput):
  r""" RemoteNodeSplitSamplerInput passes the split category to the server and
  the server loads seeds from the dataset.
  """

  def __init__(self, split: Split, input_type: str) -> None:
    r""" Store the split category and the input node type. """
    self.split = split
    self.input_type = input_type

  def to_local_sampler_input(
    self,
    dataset,
    **kwargs,
  ) -> NodeSamplerInput:
    r""" Read the seed indices of the configured split from :obj:`dataset`
    and wrap them in a :class:`NodeSamplerInput`.

    Args:
      dataset: Object exposing ``train_idx``, ``val_idx`` and ``test_idx``.
        When the selected attribute is a dict, the entry for
        :obj:`input_type` is used.

    Raises:
      ValueError: If :obj:`split` is none of ``Split.train``,
        ``Split.valid`` or ``Split.test``.
    """
    if self.split == Split.train:
      idx = dataset.train_idx
    elif self.split == Split.valid:
      idx = dataset.val_idx
    elif self.split == Split.test:
      idx = dataset.test_idx
    else:
      # Without this branch ``idx`` would be unbound below and the failure
      # would surface as an opaque UnboundLocalError.
      raise ValueError(f"Unsupported split '{self.split}' for "
                       f"'{self.__class__.__name__}' (expected Split.train, "
                       f"Split.valid or Split.test)")
    if isinstance(idx, dict):
      idx = idx[self.input_type]
    return NodeSamplerInput(node=idx, input_type=self.input_type)