def find_all_devices()

in optimum/amd/brevitas/accelerate_utils.py [0:0]


def find_all_devices(data):
    """
    Collects every tensor in a nested dict/list/tuple together with the device it lies on.

    This does not assume the tensors share a device: every tensor encountered is reported,
    in traversal order (mapping values in iteration order, sequence items in order).

    Args:
        data (nested list/tuple/dictionary of `torch.Tensor`): The data whose tensors should be located.
            Leaves that are neither a `Mapping`, `tuple`, `list` nor `torch.Tensor` are skipped.

    Returns:
        - For a `Mapping`, `tuple` or `list`: a (possibly empty) list of `(tensor, str(tensor.device))`
          pairs gathered from all nested tensors.
        - For a `torch.Tensor`: a one-element list `[(data, str(data.device))]`.
        - For any other value: `None`.
    """
    if isinstance(data, Mapping):
        children = data.values()
    elif isinstance(data, (tuple, list)):
        children = data
    elif isinstance(data, torch.Tensor):
        return [(data, str(data.device))]
    else:
        # Non-container, non-tensor leaf: nothing to report. Callers (including the
        # recursion below) distinguish this from an empty container by checking for None.
        return None

    devices = []
    for child in children:
        found = find_all_devices(child)
        if found is not None:
            devices.extend(found)
    return devices