graphlearn_torch/python/loader/neighbor_loader.py (60 lines of code) (raw):
# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Optional
import torch
from ..loader import NodeLoader
from ..data import Dataset
from ..sampler import NeighborSampler, NodeSamplerInput
from ..typing import InputNodes, NumNeighbors
class NeighborLoader(NodeLoader):
    r"""Mini-batch loader that draws seed nodes and expands them by neighbor
    sampling, for training GNNs on large-scale graphs.

    Each iteration takes the next batch of seed nodes, moves it to
    ``self.device``, and then either:

    * returns ``self.sampler.sample_pyg_v1(seeds)`` when ``as_pyg_v1`` is
      ``True``; or
    * wraps the seeds in a :class:`NodeSamplerInput`, runs
      ``self.sampler.sample_from_nodes`` on it, and returns the result of
      ``self._collate_fn``.

    Args:
        data (Dataset): The ``graphlearn_torch.data.Dataset`` to load from.
            Its ``graph`` feeds the default sampler, and its ``edge_dir`` is
            stored on the loader and handed to the default sampler.
        num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]):
            Number of neighbors to sample for each node at every hop. For
            heterogeneous graphs this may be a dictionary giving the count
            per edge type. An entry of :obj:`-1` includes all neighbors.
            Only used when ``neighbor_sampler`` is ``None``.
        input_nodes (torch.Tensor or str or Tuple[str, torch.Tensor]): Seed
            node indices from which mini-batches are built. Give them as a
            :obj:`torch.LongTensor` or :obj:`torch.BoolTensor`; for
            heterogeneous graphs pass a ``(node_type, indices)`` tuple.
        neighbor_sampler (NeighborSampler, optional): A pre-built sampler to
            use as-is. When supplied, ``num_neighbors``, ``strategy``,
            ``with_edge``, ``with_weight`` and ``seed`` are not used by this
            constructor. (default: ``None``)
        batch_size (int): Number of seed nodes per mini-batch.
            (default: ``1``)
        shuffle (bool): If ``True``, reshuffle the seeds every epoch.
            (default: ``False``)
        drop_last (bool): If ``True``, drop the final batch when it is
            smaller than ``batch_size``; otherwise keep it as a smaller
            batch. (default: ``False``)
        with_edge (bool): If ``True``, sample with edge ids and include them
            in the sampled results. (default: ``False``)
        with_weight (bool): Forwarded as ``with_weight`` to the default
            :class:`NeighborSampler`. (default: ``False``)
        strategy (str): Sampling strategy for the default neighbor sampler
            provided by graphlearn-torch. (default: ``"random"``)
        device (torch.device): Passed to the default sampler and to
            :class:`NodeLoader`. Each seed batch is moved to ``self.device``
            before sampling. (default: ``torch.device('cuda:0')``)
        as_pyg_v1 (bool): If ``True``, return results in the format of the
            PyG v1 ``NeighborSampler``. (default: ``False``)
        seed (int, optional): Forwarded as ``seed`` to the default
            :class:`NeighborSampler`. (default: ``None``)
        **kwargs: Extra keyword arguments forwarded unchanged to
            :class:`NodeLoader`.
    """

    def __init__(
        self,
        data: Dataset,
        num_neighbors: NumNeighbors,
        input_nodes: InputNodes,
        neighbor_sampler: Optional[NeighborSampler] = None,
        batch_size: int = 1,
        shuffle: bool = False,
        drop_last: bool = False,
        with_edge: bool = False,
        with_weight: bool = False,
        strategy: str = 'random',
        device: torch.device = torch.device('cuda:0'),
        as_pyg_v1: bool = False,
        seed: Optional[int] = None,
        **kwargs
    ):
        # The same edge direction is used by the default sampler and stored on
        # the loader, so read it from the dataset once.
        edge_dir = data.edge_dir

        if neighbor_sampler is None:
            # Only build the default sampler when the caller did not provide
            # one; a caller-supplied sampler is used unchanged.
            neighbor_sampler = NeighborSampler(
                data.graph,
                num_neighbors=num_neighbors,
                strategy=strategy,
                with_edge=with_edge,
                with_weight=with_weight,
                device=device,
                edge_dir=edge_dir,
                seed=seed,
            )

        # Set before delegating to NodeLoader, matching the original ordering.
        self.as_pyg_v1 = as_pyg_v1
        self.edge_dir = edge_dir

        super().__init__(
            data=data,
            node_sampler=neighbor_sampler,
            input_nodes=input_nodes,
            device=device,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            **kwargs,
        )

    def __next__(self):
        """Sample and return the next mini-batch.

        Fetches the next seed batch from ``self._seeds_iter`` and moves it to
        ``self.device``. In PyG v1 mode the result of
        ``self.sampler.sample_pyg_v1`` is returned directly. Otherwise the
        seeds are sampled with ``sample_from_nodes`` and passed through
        ``self._collate_fn``.

        NOTE(review): end of iteration is presumably signalled by
        ``StopIteration`` raised from ``_seeds_iter._next_data()`` and
        propagated unchanged. Confirm against ``NodeLoader``.
        """
        seed_batch = self._seeds_iter._next_data().to(self.device)

        if self.as_pyg_v1:
            return self.sampler.sample_pyg_v1(seed_batch)

        sampler_input = NodeSamplerInput(
            node=seed_batch,
            input_type=self._input_type,
        )
        sampled = self.sampler.sample_from_nodes(sampler_input)
        return self._collate_fn(sampled)