in graphlearn_torch/python/utils/tensor.py [0:0]
def apply_to_all_tensor(data: Any, tensor_method, *args, **kwargs):
  r""" Apply the specified method to all tensors contained by the
  input data recursively.

  ``dict`` values, ``list`` items and ``tuple`` items are traversed
  recursively and new containers are built for the results; the input
  containers themselves are not modified. Dicts are rebuilt as plain
  ``dict`` and lists as plain ``list``. Named tuples are rebuilt with their
  own class (so field access keeps working); any other tuple is rebuilt
  as a plain ``tuple``. Values that are neither one of these containers
  nor a ``torch.Tensor`` are returned unchanged.

  Args:
    data: The (possibly nested) object to process.
    tensor_method: Callable invoked as
      ``tensor_method(tensor, *args, **kwargs)`` for every tensor found.
    *args: Extra positional arguments forwarded to ``tensor_method``.
    **kwargs: Extra keyword arguments forwarded to ``tensor_method``.

  Returns:
    A structure mirroring ``data`` in which every tensor has been replaced
    by the value returned from ``tensor_method``.
  """
  if isinstance(data, dict):
    return {
      k: apply_to_all_tensor(v, tensor_method, *args, **kwargs)
      for k, v in data.items()
    }
  if isinstance(data, list):
    return [
      apply_to_all_tensor(v, tensor_method, *args, **kwargs) for v in data
    ]
  if isinstance(data, tuple):
    items = [
      apply_to_all_tensor(v, tensor_method, *args, **kwargs) for v in data
    ]
    if hasattr(data, '_fields') and hasattr(data, '_make'):
      # Named tuple: rebuild with the same class so that its fields and
      # type survive; a plain ``tuple(...)`` would silently drop them.
      return data._make(items)
    return tuple(items)
  if isinstance(data, torch.Tensor):
    return tensor_method(data, *args, **kwargs)
  return data