in graphlearn_torch/python/partition/partition_book.py [0:0]
def __init__(self, partition_ranges: List[Tuple[int, int]], partition_idx: int):
    """Build a partition book from contiguous per-partition ID ranges.

    Args:
        partition_ranges: One ``(start, end)`` pair per partition, in
            partition order. Every pair must satisfy ``start < end`` and
            every ``end`` must equal the next pair's ``start``. The
            adjacency rule suggests half-open ``[start, end)`` ranges
            (NOTE(review): confirm against the lookup code).
        partition_idx: Position in ``partition_ranges`` of the partition
            this book describes. Must satisfy
            ``0 <= partition_idx < len(partition_ranges)``; negative
            indices are rejected rather than wrapped around.

    Raises:
        ValueError: If any range has ``start >= end``, or two consecutive
            ranges are not adjacent.
        IndexError: If ``partition_idx`` is not a valid non-negative index
            into ``partition_ranges`` (always the case when it is empty).
    """
    if not all(rng[0] < rng[1] for rng in partition_ranges):
        raise ValueError("All partition ranges must have start < end")
    # zip() stops at the shorter input, so pairing the list with its own
    # tail yields every (previous, next) neighbour pair.
    if not all(prev[1] == nxt[0]
               for prev, nxt in zip(partition_ranges, partition_ranges[1:])):
        raise ValueError("Partition ranges must be continuous")
    # Validate explicitly: a negative index would otherwise wrap around in
    # the list lookup below and silently pick another partition's offset,
    # while self.partition_idx kept the negative value.
    num_partitions = len(partition_ranges)
    if not 0 <= partition_idx < num_partitions:
        raise IndexError(
            f"partition_idx {partition_idx} is out of range for "
            f"{num_partitions} partition range(s)")
    # Only the end of each range is stored; with contiguous ranges the
    # start of range i equals the end of range i - 1.
    self.partition_bounds = torch.tensor(
        [end for _, end in partition_ranges], dtype=torch.long)
    self.partition_idx = partition_idx
    # The ID-to-index helper is built from the start of this partition's
    # range (presumably subtracted from global IDs — confirm in
    # OffsetId2Index).
    self._id2index = OffsetId2Index(partition_ranges[partition_idx][0])