graphlearn_torch/python/loader/link_loader.py (114 lines of code) (raw):
# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Tuple, Union, Optional, Literal
import torch
from .transform import to_data, to_hetero_data
from ..utils import convert_to_tensor
from ..data import Dataset
from ..sampler import (
BaseSampler,
EdgeSamplerInput,
NegativeSampling,
SamplerOutput,
HeteroSamplerOutput
)
from ..typing import InputEdges, reverse_edge_type
class LinkLoader(object):
    r"""A data loader that performs mini-batch sampling from link information,
    using a generic :class:`~graphlearn_torch.sampler.BaseSampler`
    implementation that defines a
    :meth:`~graphlearn_torch.sampler.BaseSampler.sample_from_edges` function and
    is supported on the provided input :obj:`data` object.

    .. note::
        Negative sampling for triplet case is currently implemented in an
        approximate way, *i.e.* negative edges may contain false negatives.

    Args:
        data (Dataset): The `graphlearn_torch.data.Dataset` object.
        link_sampler (graphlearn_torch.sampler.BaseSampler): The sampler
            implementation to be used with this loader.
            Needs to implement
            :meth:`~graphlearn_torch.sampler.BaseSampler.sample_from_edges`.
            The sampler implementation must be compatible with the input
            :obj:`data` object.
        edge_label_index (Tensor or EdgeType or Tuple[EdgeType, Tensor]):
            The edge indices, holding source and destination nodes to start
            sampling from.
            If set to :obj:`None`, all edges will be considered.
            In heterogeneous graphs, needs to be passed as a tuple that holds
            the edge type and corresponding edge indices.
            (default: :obj:`None`)
        edge_label (Tensor, optional): The labels of edge indices from which to
            start sampling from. Must be the same length as
            the :obj:`edge_label_index`. (default: :obj:`None`)
        neg_sampling (NegativeSampling, optional): The negative sampling
            configuration.
            For negative sampling mode :obj:`"binary"`, samples can be accessed
            via the attributes :obj:`edge_label_index` and :obj:`edge_label` in
            the respective edge type of the returned mini-batch.
            In case :obj:`edge_label` does not exist, it will be automatically
            created and represents a binary classification task (:obj:`0` =
            negative edge, :obj:`1` = positive edge).
            In case :obj:`edge_label` does exist, it has to be a categorical
            label from :obj:`0` to :obj:`num_classes - 1`.
            After negative sampling, label :obj:`0` represents negative edges,
            and labels :obj:`1` to :obj:`num_classes` represent the labels of
            positive edges.
            Note that returned labels are of type :obj:`torch.float` for binary
            classification (to facilitate the ease-of-use of
            :meth:`F.binary_cross_entropy`) and of type
            :obj:`torch.long` for multi-class classification (to facilitate the
            ease-of-use of :meth:`F.cross_entropy`).
            For negative sampling mode :obj:`"triplet"`, samples can be
            accessed via the attributes :obj:`src_index`, :obj:`dst_pos_index`
            and :obj:`dst_neg_index` in the respective node types of the
            returned mini-batch.
            :obj:`edge_label` needs to be :obj:`None` for :obj:`"triplet"`
            negative sampling mode.
            If set to :obj:`None`, no negative sampling strategy is applied.
            (default: :obj:`None`)
        device (torch.device, optional): The device the seed edge indices of
            each mini-batch are moved to before sampling. If set to
            :obj:`None`, the seed indices are left where the seed
            :class:`~torch.utils.data.DataLoader` produced them (the CPU).
            (default: :obj:`torch.device('cuda:0')`)
        edge_dir (str: ["in", "out"]): The edge direction for sampling.
            Can be either :obj:`"out"` or :obj:`"in"`.
            (default: :obj:`"out"`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.

    Raises:
        ValueError: If :obj:`edge_dir` is neither :obj:`"out"` nor
            :obj:`"in"`, or if :obj:`edge_label` is given together with
            :obj:`"triplet"` negative sampling.
    """

    def __init__(
        self,
        data: Dataset,
        link_sampler: BaseSampler,
        edge_label_index: InputEdges = None,
        edge_label: Optional[torch.Tensor] = None,
        neg_sampling: Optional[NegativeSampling] = None,
        device: Optional[torch.device] = torch.device('cuda:0'),
        edge_dir: Literal['out', 'in'] = 'out',
        **kwargs,
    ):
        # Fail fast: `_collate_fn` only knows how to look up edge features for
        # these two directions.
        if edge_dir not in ('out', 'in'):
            raise ValueError(
                f"'edge_dir' must be either 'out' or 'in', got {edge_dir!r}.")

        # Get edge type (or `None` for homogeneous graphs):
        input_type, edge_label_index = get_edge_label_index(
            data, edge_label_index)

        self.data = data
        self.link_sampler = link_sampler
        self.neg_sampling = NegativeSampling.cast(neg_sampling)
        self.device = device
        self.edge_dir = edge_dir

        if (self.neg_sampling is not None and self.neg_sampling.is_binary()
                and edge_label is not None and edge_label.numel() > 0
                and edge_label.min() == 0):
            # Increment labels such that `zero` now denotes "negative".
            edge_label = edge_label + 1

        if (self.neg_sampling is not None and self.neg_sampling.is_triplet()
                and edge_label is not None):
            raise ValueError("'edge_label' needs to be undefined for "
                             "'triplet'-based negative sampling. Please use "
                             "`src_index`, `dst_pos_index` and "
                             "`dst_neg_index` of the returned mini-batch "
                             "instead to differentiate between positive and "
                             "negative samples.")

        # Clone so later in-place changes by the caller cannot affect sampling.
        self.input_data = EdgeSamplerInput(
            row=edge_label_index[0].clone(),
            col=edge_label_index[1].clone(),
            label=edge_label,
            input_type=input_type,
            neg_sampling=self.neg_sampling,
        )

        # The inner DataLoader only batches/shuffles positions into the seed
        # edge list; the actual sampling happens in `__next__`.
        input_index = range(len(edge_label_index[0]))
        self._seed_loader = torch.utils.data.DataLoader(input_index, **kwargs)

    def __iter__(self):
        self._seeds_iter = iter(self._seed_loader)
        return self

    def __next__(self):
        # Use the public iterator protocol so the DataLoader iterator keeps its
        # own bookkeeping; StopIteration propagates to end the epoch.
        seeds = next(self._seeds_iter)
        if self.device is not None:
            seeds = seeds.to(self.device)
        # For out-edge sampling the sampled subgraph's edges are reversed
        # relative to the stored graph, so that features can be aggregated from
        # k-hop to (k-1)-hop nodes; `_collate_fn` accounts for this.
        sampler_out = self.link_sampler.sample_from_edges(self.input_data[seeds])
        return self._collate_fn(sampler_out)

    def _collate_fn(self, sampler_out: Union[SamplerOutput, HeteroSamplerOutput]):
        r"""Format the sampler output into a ``Data``/``HeteroData`` object.

        For the out-edge sampling scheme (i.e. the direction of edges in
        the output is inverse to the original graph), we put the reversed
        edge_label_index into the (dst, rev_to, src) subgraph for
        HeteroSamplerOutput and (dst, to, src) for SamplerOutput.
        However, for the in-edge sampling scheme (i.e. the direction of edges
        in the output is the same as the original graph), we do not need to
        reverse the edge type of the sampler_out.
        """
        if isinstance(sampler_out, SamplerOutput):
            x = self.data.node_features[sampler_out.node]
            if self.data.edge_features is not None and sampler_out.edge is not None:
                edge_attr = self.data.edge_features[sampler_out.edge]
            else:
                edge_attr = None
            return to_data(sampler_out,
                           node_feats=x,
                           edge_feats=edge_attr,
                           )

        # Heterogeneous output.
        x_dict = {
            ntype: self.data.get_node_feature(ntype)[ids.to(torch.int64)]
            for ntype, ids in sampler_out.node.items()
        }
        edge_attr_dict = {}
        if sampler_out.edge is not None:
            for etype, eids in sampler_out.edge.items():
                # With out-edge sampling the output edge type is the reverse of
                # the stored one, so features live under the reversed type.
                feat_etype = (reverse_edge_type(etype)
                              if self.edge_dir == 'out' else etype)
                efeat = self.data.get_edge_feature(feat_etype)
                if efeat is not None:
                    edge_attr_dict[etype] = efeat[eids.to(torch.int64)]
        return to_hetero_data(sampler_out,
                              node_feat_dict=x_dict,
                              edge_feat_dict=edge_attr_dict,
                              edge_dir=self.edge_dir,
                              )

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'
def get_edge_label_index(
    data: Dataset,
    edge_label_index: InputEdges
) -> Tuple[Optional[Tuple[str, ...]],
           Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
    r"""Resolve ``edge_label_index`` into an ``(edge_type, edge_index)`` pair.

    Accepted forms of ``edge_label_index``:

    * ``None``: every edge of the homogeneous graph, taken from
      ``data.get_graph(None)`` in COO form.
    * any non-tuple value (e.g. a tensor): homogeneous edge indices, passed
      through ``convert_to_tensor``.
    * an edge type (a tuple whose first element is a ``str``): every edge of
      that type, in COO form.
    * a 2-tuple ``(edge_type, edge_index)``: the given indices for that edge
      type; if ``edge_index`` is ``None``, every edge of that type.

    Args:
        data (Dataset): The dataset providing the graph topology.
        edge_label_index (InputEdges): The user-supplied edge specification.

    Returns:
        A tuple ``(edge_type, edge_index)``. ``edge_type`` is ``None`` for
        homogeneous input. ``edge_index`` is either a ``(row, col)`` pair read
        from the COO topology or the value produced by ``convert_to_tensor``.

    Raises:
        ValueError: If a tuple is given that is neither an edge type nor an
            ``(edge_type, edge_index)`` pair.
    """
    def _get_edge_index(edge_type):
        # Need the edge index in COO form to enumerate all edges of a graph.
        row, col, _, _ = data.get_graph(edge_type).topo.to_coo()
        return row, col

    if not isinstance(edge_label_index, tuple):
        if edge_label_index is None:
            return None, _get_edge_index(None)
        return None, convert_to_tensor(edge_label_index)

    # A tuple of strings is an edge type on its own: use all its edges.
    if edge_label_index and isinstance(edge_label_index[0], str):
        return edge_label_index, _get_edge_index(edge_label_index)

    if len(edge_label_index) != 2:
        raise ValueError(
            "'edge_label_index' must be an edge type or an "
            "(edge_type, edge_index) pair, got a tuple of length "
            f"{len(edge_label_index)}.")

    # NOTE(review): presumably `convert_to_tensor` converts the index element
    # and leaves the edge type as-is — verify against ..utils.
    edge_type, edge_label_index = convert_to_tensor(edge_label_index)
    if edge_label_index is None:
        return edge_type, _get_edge_index(edge_type)
    return edge_type, edge_label_index