graphlearn_torch/python/partition/partition_book.py (39 lines of code) (raw):
import torch
from typing import List, Tuple
from .base import PartitionBook
class RangePartitionBook(PartitionBook):
r"""A class for managing range-based partitions of consecutive IDs.
Suitable when IDs within each partition are consecutive.
Args:
partition_ranges (List[Tuple[int, int]]): A list of tuples representing
the start and end (exclusive) of each partition range.
partition_idx (int): The index of the current partition.
Example:
>>> partition_ranges = [(0, 10), (10, 20), (20, 30)]
>>> range_pb = RangePartitionBook(partition_ranges, partition_idx=1)
>>> indices = torch.tensor([0, 5, 10, 15, 20, 25])
>>> partition_ids = range_pb[indices]
>>> print(partition_ids)
tensor([0, 0, 1, 1, 2, 2])
"""
def __init__(self, partition_ranges: List[Tuple[int, int]], partition_idx: int):
if not all(r[0] < r[1] for r in partition_ranges):
raise ValueError("All partition ranges must have start < end")
if not all(r1[1] == r2[0] for r1, r2 in zip(partition_ranges[:-1], partition_ranges[1:])):
raise ValueError("Partition ranges must be continuous")
self.partition_bounds = torch.tensor(
[end for _, end in partition_ranges], dtype=torch.long)
self.partition_idx = partition_idx
self._id2index = OffsetId2Index(partition_ranges[partition_idx][0])
def __getitem__(self, indices: torch.Tensor) -> torch.Tensor:
return torch.searchsorted(self.partition_bounds, indices, right=True)
@property
def device(self):
return self.partition_bounds.device
@property
def id2index(self):
return self._id2index
def id_filter(self, node_pb: PartitionBook, partition_idx: int):
start = self.partition_bounds[partition_idx-1] if partition_idx > 0 else 0
end = self.partition_bounds[partition_idx]
return torch.arange(start, end)
class OffsetId2Index:
r"""
Convert global IDs to local indices by subtracting a specified offset.
"""
def __init__(self, offset: int):
self.offset = offset
def __getitem__(self, ids: torch.Tensor) -> torch.Tensor:
local_indices = ids - self.offset
return local_indices
def to(self, device):
# device is always same as the input ids
return self
class GLTPartitionBook(PartitionBook, torch.Tensor):
r""" A partition book of graph nodes or edges.
"""
def __getitem__(self, indices) -> torch.Tensor:
return torch.Tensor.__getitem__(self, indices)