def convert_to_tensor()

in graphlearn_torch/python/utils/tensor.py [0:0]


def convert_to_tensor(data: Any, dtype: torch.dtype = None):
  r""" Recursively convert numpy arrays (and tensors) inside ``data`` to
  torch tensors, optionally casting them to ``dtype``.

  Containers are walked recursively:
    - ``dict``: returns a new plain ``dict`` with the same keys and
      converted values.
    - ``list``: returns a new ``list`` of converted elements.
    - ``tuple``: returns a new tuple of converted elements. Named tuples
      keep their original type so attribute access still works on the
      result; other tuples become plain ``tuple``.

  Leaves are handled as follows:
    - ``torch.Tensor``: cast with ``Tensor.type(dtype)`` when ``dtype`` is
      given (no copy if it already has that dtype), otherwise returned
      as-is.
    - ``numpy.ndarray``: wrapped with ``torch.from_numpy`` (the tensor
      shares memory with the array), then cast to ``dtype`` if given.
    - Anything else (scalars, strings, ``None``, ...) is returned unchanged.

  Args:
    data: The value or nested container to convert.
    dtype: Optional target dtype applied to every produced tensor. When
      ``None``, tensors keep their existing dtype.

  Returns:
    The converted data, mirroring the container structure of ``data``.
  """
  if isinstance(data, dict):
    return {k: convert_to_tensor(v, dtype) for k, v in data.items()}
  if isinstance(data, list):
    return [convert_to_tensor(v, dtype) for v in data]
  if isinstance(data, tuple):
    items = [convert_to_tensor(v, dtype) for v in data]
    if hasattr(data, '_fields'):
      # Named tuple: rebuild the same type so field access survives.
      return type(data)(*items)
    return tuple(items)
  if isinstance(data, torch.Tensor):
    return data.type(dtype) if dtype is not None else data
  if isinstance(data, numpy.ndarray):
    tensor = torch.from_numpy(data)
    return tensor.type(dtype) if dtype is not None else tensor
  return data