in optimum/amd/brevitas/accelerate_utils.py [0:0]
def find_all_devices(data):
    """
    Collects every tensor in a nested dict/list/tuple together with the device it lives on.

    Tensors are not assumed to share a device: each tensor found is reported
    individually, in traversal order (mapping values in iteration order,
    sequence items left to right).

    Args:
        data (nested list/tuple/dictionary of `torch.Tensor`): The data to inspect.

    Returns:
        For a tensor, or a Mapping/tuple/list (possibly empty), a list of
        `(tensor, str(tensor.device))` pairs, one per tensor found.
        For any other leaf value (e.g. an int, a string, `None`), `None`.
    """
    if isinstance(data, torch.Tensor):
        return [(data, str(data.device))]
    if isinstance(data, Mapping):
        children = data.values()
    elif isinstance(data, (tuple, list)):
        children = data
    else:
        # Non-tensor leaves carry no device; returning None (rather than [])
        # preserves the original contract for callers that test `is None`.
        return None

    devices = []
    for child in children:
        found = find_all_devices(child)
        # Children that are non-tensor leaves yield None and are skipped.
        if found is not None:
            devices.extend(found)
    return devices