def apply_to_all_tensor()

in graphlearn_torch/python/utils/tensor.py [0:0]


def apply_to_all_tensor(data: Any, tensor_method, *args, **kwargs):
  r""" Apply the specified method to all tensors contained by the
  input data recursively.

  Containers are traversed and rebuilt as follows:

  * ``dict`` (including subclasses) is rebuilt as a plain ``dict`` with the
    same keys.
  * ``list`` (including subclasses) is rebuilt as a plain ``list``.
  * Named tuples are rebuilt with their own type, so field access is kept.
  * Any other ``tuple`` is rebuilt as a plain ``tuple``.

  Every ``torch.Tensor`` leaf is replaced by
  ``tensor_method(tensor, *args, **kwargs)``. Any other value is returned
  as-is (the same object, not a copy).

  Args:
    data: An arbitrarily nested structure of dicts, lists, tuples and
      tensors.
    tensor_method: Callable invoked on each tensor leaf, with the tensor as
      its first positional argument.
    *args: Extra positional arguments forwarded to ``tensor_method``.
    **kwargs: Extra keyword arguments forwarded to ``tensor_method``.

  Returns:
    A new structure mirroring ``data``, in which every tensor is replaced by
    the result of ``tensor_method``.
  """
  if isinstance(data, dict):
    return {
      k: apply_to_all_tensor(v, tensor_method, *args, **kwargs)
      for k, v in data.items()
    }
  if isinstance(data, list):
    return [apply_to_all_tensor(v, tensor_method, *args, **kwargs)
            for v in data]
  if isinstance(data, tuple):
    items = [apply_to_all_tensor(v, tensor_method, *args, **kwargs)
             for v in data]
    if hasattr(data, '_fields'):
      # Named tuple: rebuild with its own type. A plain ``tuple(...)`` would
      # silently drop the field names that callers may access.
      return type(data)(*items)
    return tuple(items)
  if isinstance(data, torch.Tensor):
    return tensor_method(data, *args, **kwargs)
  return data