def align_input()

in optimum/amd/brevitas/accelerate_utils.py [0:0]


def align_input(model, device_map):
    """Attach an ``AlignDevicesHook`` to ``model`` bound to its main device.

    The main device is the first value in ``device_map`` that is neither
    ``"cpu"`` nor ``"disk"``. When no such value exists (CPU-only,
    CPU + disk, disk-only, or an empty map) the main device is ``"cpu"``.

    Args:
        model: Module the hook is attached to via ``add_hook_to_module``.
        device_map: Mapping whose values are device identifiers; only the
            values are inspected here.

    Returns:
        The same ``model`` object that was passed in, after the hook has
        been added.
    """
    # A single pass with a default replaces the two set comparisons and the
    # `[...][0]` lookup, which raised IndexError when the map was empty or
    # contained only "disk" entries.
    main_device = next(
        (device for device in device_map.values() if device not in ("cpu", "disk")),
        "cpu",
    )
    # NOTE(review): io_same_device=True presumably asks the hook to return
    # outputs on the device the inputs came from -- confirm against the
    # Accelerate hooks documentation.
    hook = AlignDevicesHook(
        execution_device=main_device,
        io_same_device=True,
        skip_keys=None,
        tied_params_map=None,
    )
    add_hook_to_module(model, hook)
    return model