in optimum/amd/brevitas/data_utils.py [0:0]
def recursive_to_device(
    tensor_or_iterable: Union[Iterable, torch.Tensor], device
) -> Union[Iterable, torch.Tensor]:
    """Recursively move every tensor in a (possibly nested) container to ``device``.

    Args:
        tensor_or_iterable: A ``torch.Tensor`` or a container of them. Supported
            containers are tuples, mutable mappings (e.g. ``dict``) and mutable
            sequences (e.g. ``list``), nested arbitrarily.
        device: Argument forwarded to ``torch.Tensor.to`` for every tensor leaf.

    Returns:
        For a tensor, the result of ``tensor.to(device)``. For a tuple, a new
        ``tuple`` of converted items (tuples are immutable). For a mutable
        mapping or sequence, the same object, updated in place.

    Raises:
        ValueError: If a value is neither a tensor nor a supported container
            (strings are rejected here rather than being iterated).
    """
    from collections.abc import MutableMapping, MutableSequence

    if isinstance(tensor_or_iterable, torch.Tensor):
        return tensor_or_iterable.to(device)
    if isinstance(tensor_or_iterable, tuple):
        # Special handling of tuples, since they are immutable: build a new one.
        return tuple(recursive_to_device(item, device) for item in tensor_or_iterable)
    if isinstance(tensor_or_iterable, MutableMapping):
        # Recurse on the *values*; keys are left untouched. Snapshot the keys so
        # the mapping is never iterated while being written to.
        for key in list(tensor_or_iterable.keys()):
            tensor_or_iterable[key] = recursive_to_device(tensor_or_iterable[key], device)
        return tensor_or_iterable
    if isinstance(tensor_or_iterable, MutableSequence):
        # Assign by position, not by element value.
        for idx, item in enumerate(tensor_or_iterable):
            tensor_or_iterable[idx] = recursive_to_device(item, device)
        return tensor_or_iterable
    raise ValueError(f"Cannot move {type(tensor_or_iterable)} to {device}")