graphlearn_torch/python/utils/topo.py (46 lines of code) (raw):
# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Optional, Tuple
import torch
import torch_sparse
def ptr2ind(ptr: torch.Tensor) -> torch.Tensor:
r""" Convert an index pointer tensor to an indice tensor.
"""
ind = torch.arange(ptr.numel() - 1, device=ptr.device)
return ind.repeat_interleave(ptr[1:] - ptr[:-1])
def coo_to_csr(
row: torch.Tensor,
col: torch.Tensor,
edge_id: Optional[torch.Tensor] = None,
edge_weight: Optional[torch.Tensor] = None,
node_sizes: Optional[Tuple[int, int]] = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
r""" Tranform edge index from COO to CSR.
Args:
row (torch.Tensor): The row indices.
col (torch.Tensor): The column indices.
edge_id (torch.Tensor, optional): The edge ids corresponding to the input
edge index.
edge_weight (torch.Tensor, optional): The edge weights corresponding to the
input edge index.
node_sizes (Tuple[int, int], optional): The number of nodes in row and col.
"""
if node_sizes is None:
node_sizes = (int(row.max()) + 1, int(col.max()) + 1)
assert len(node_sizes) == 2
assert row.numel() == col.numel()
if edge_id is not None:
assert edge_id.numel() == row.numel()
adj_t = torch_sparse.SparseTensor(
row=row, col=col, value=edge_id, sparse_sizes=node_sizes
)
edge_ids, edge_weights = adj_t.storage.value(), None
if edge_weight is not None:
assert edge_weight.numel() == row.numel()
adj_w = torch_sparse.SparseTensor(
row=row, col=col, value=edge_weight, sparse_sizes=node_sizes
)
edge_weights = adj_w.storage.value()
return adj_t.storage.rowptr(), adj_t.storage.col(), edge_ids, edge_weights
def coo_to_csc(
row: torch.Tensor,
col: torch.Tensor,
edge_id: Optional[torch.Tensor] = None,
edge_weight: Optional[torch.Tensor] = None,
node_sizes: Optional[Tuple[int, int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
r""" Tranform edge index from COO to CSC.
Args:
row (torch.Tensor): The row indices.
col (torch.Tensor): The column indices.
edge_id (torch.Tensor, optional): The edge ids corresponding to the input
edge index.
edge_weight (torch.Tensor, optional): The edge weights corresponding to the
input edge index.
node_sizes (Tuple[int, int], optional): The number of nodes in row and col.
"""
if node_sizes is not None:
node_sizes = (node_sizes[1], node_sizes[0])
r_colptr, r_row, r_edge_id, r_edge_weight = coo_to_csr(
col, row, edge_id, edge_weight, node_sizes
)
return r_row, r_colptr, r_edge_id, r_edge_weight