graphlearn_torch/python/utils/tensor.py (62 lines of code) (raw):

# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== from typing import Any, List, Union import numpy import torch def tensor_equal_with_device(lhs: torch.Tensor, rhs: torch.Tensor): r""" Check whether the data and device of two tensors are same. """ if lhs.device == rhs.device: return torch.equal(lhs, rhs) return False def id2idx(ids: Union[List[int], torch.Tensor]): r""" Get tensor of mapping from id to its original index. """ if not isinstance(ids, torch.Tensor): ids = torch.tensor(ids, dtype=torch.int64) ids = ids.to(torch.int64) max_id = torch.max(ids).item() id2idx = torch.zeros(max_id + 1, dtype=torch.int64, device=ids.device) id2idx[ids] = torch.arange(ids.size(0), dtype=torch.int64, device=ids.device) return id2idx def convert_to_tensor(data: Any, dtype: torch.dtype = None): r""" Convert the input data to a tensor based type. """ if isinstance(data, dict): new_data = {} for k, v in data.items(): new_data[k] = convert_to_tensor(v, dtype) return new_data if isinstance(data, list): new_data = [] for v in data: new_data.append(convert_to_tensor(v, dtype)) return new_data if isinstance(data, tuple): return tuple(convert_to_tensor(list(data), dtype)) if isinstance(data, torch.Tensor): return data.type(dtype) if dtype is not None else data if isinstance(data, numpy.ndarray): return ( torch.from_numpy(data).type(dtype) if dtype is not None else torch.from_numpy(data) ) return data def apply_to_all_tensor(data: Any, tensor_method, *args, **kwargs): r""" Apply the specified method to all tensors contained by the input data recursively. """ if isinstance(data, dict): new_data = {} for k, v in data.items(): new_data[k] = apply_to_all_tensor(v, tensor_method, *args, **kwargs) return new_data if isinstance(data, list): new_data = [] for v in data: new_data.append(apply_to_all_tensor(v, tensor_method, *args, **kwargs)) return new_data if isinstance(data, tuple): return tuple(apply_to_all_tensor(list(data), tensor_method, *args, **kwargs)) if isinstance(data, torch.Tensor): return tensor_method(data, *args, **kwargs) return data def share_memory(data: Any): r""" Share memory for all tensors contained by the input data. """ return apply_to_all_tensor(data, torch.Tensor.share_memory_) def squeeze(data: Any): r""" Squeeze all tensors contained by the input data. """ return apply_to_all_tensor(data, torch.Tensor.squeeze)