in graphlearn_torch/python/utils/tensor.py [0:0]
def convert_to_tensor(data: Any, dtype: torch.dtype = None):
  r""" Recursively convert numpy arrays in ``data`` to torch tensors.

  Containers are walked recursively and rebuilt with converted contents:

  * ``dict`` -> a new plain ``dict`` with the same keys and converted values.
  * ``list`` -> a new ``list`` of converted items.
  * ``tuple`` -> a new ``tuple`` of converted items; a ``namedtuple`` is
    rebuilt as the same namedtuple type so its field names are preserved.

  Leaf values are handled as follows:

  * ``numpy.ndarray`` -> ``torch.from_numpy(data)``, optionally cast to
    ``dtype``.
  * ``torch.Tensor`` -> returned as-is, or cast with ``Tensor.type(dtype)``
    when ``dtype`` is given.
  * anything else -> returned unchanged.

  Args:
    data: The value or (nested) container to convert.
    dtype: If not ``None``, every tensor produced or encountered is cast to
      this dtype via ``Tensor.type``.

  Returns:
    A structure mirroring ``data`` with numpy arrays replaced by tensors.
  """
  if isinstance(data, dict):
    return {k: convert_to_tensor(v, dtype) for k, v in data.items()}
  if isinstance(data, list):
    return [convert_to_tensor(v, dtype) for v in data]
  if isinstance(data, tuple):
    items = [convert_to_tensor(v, dtype) for v in data]
    # A namedtuple's constructor takes fields positionally; calling
    # ``tuple(items)`` would silently drop the namedtuple type.
    if hasattr(data, '_fields'):
      return type(data)(*items)
    return tuple(items)
  if isinstance(data, numpy.ndarray):
    # ``torch.from_numpy`` shares memory with the array; a cast to a
    # different dtype below produces a copy.
    data = torch.from_numpy(data)
  if isinstance(data, torch.Tensor):
    return data.type(dtype) if dtype is not None else data
  return data