in optimum/amd/brevitas/accelerate_utils.py [0:0]
def align_input(model, device_map):
    """Attach an ``AlignDevicesHook`` to ``model`` using the main device of ``device_map``.

    The main (execution) device is chosen from the values of ``device_map``:

    * if every value is ``"cpu"`` or ``"disk"``, the main device is ``"cpu"``;
    * otherwise it is the first value, in ``device_map`` iteration order, that
      is neither ``"cpu"`` nor ``"disk"``.

    A single ``AlignDevicesHook(execution_device=main_device, io_same_device=True)``
    is then added to the top-level ``model`` via ``add_hook_to_module``. Per the
    accelerate documentation, ``execution_device`` is where inputs are placed
    before the forward pass and ``io_same_device=True`` places outputs back on
    the device the inputs came from.

    Args:
        model: The module to attach the hook to.
        device_map: Mapping whose values are device identifiers (e.g. ``"cpu"``,
            ``"disk"``, or an accelerator device) — assumed to be an
            accelerate-style device map; confirm against callers.

    Returns:
        The same ``model`` object that was passed in, after the hook was added.

    Raises:
        ValueError: If ``device_map`` has no entries, since no device can be
            chosen from it.
    """
    devices = list(device_map.values())
    if not devices:
        raise ValueError("`device_map` is empty; cannot determine the device to align inputs to.")

    # "cpu" and "disk" are not accelerators. A map made only of them (including
    # an all-"disk" map, which previously raised IndexError) executes on CPU;
    # otherwise the first accelerator in map order is the execution device.
    offload_devices = {"cpu", "disk"}
    if set(devices) <= offload_devices:
        main_device = "cpu"
    else:
        main_device = next(d for d in devices if d not in offload_devices)

    hook = AlignDevicesHook(execution_device=main_device, io_same_device=True, skip_keys=None, tied_params_map=None)
    # NOTE(review): add_hook_to_module is called with its defaults; per the
    # accelerate docs this replaces (rather than chains) any hook already
    # attached to `model` — confirm that is intended for dispatched models.
    add_hook_to_module(model, hook)
    return model