def get_target_class_for_model_and_task()

in src/optimum/nvidia/pipelines/__init__.py [0:0]


def get_target_class_for_model_and_task(task: str, architecture: str) -> Optional[Type]:
    """Resolve the target registered for a task / architecture pair.

    The lookup goes through ``SUPPORTED_MODEL_WITH_TASKS``: first by ``task``
    to obtain the per-task mapping, then by ``architecture`` inside that
    mapping.

    Args:
        task: Key into ``SUPPORTED_MODEL_WITH_TASKS``.
        architecture: Key into the mapping registered for ``task``.

    Returns:
        The entry stored for ``architecture`` under ``task`` (presumably a
        class, per the return annotation). Despite the ``Optional``
        annotation, this function never returns ``None``; it raises instead.

    Raises:
        NotImplementedError: If ``task`` has no entry (or a falsy/empty one),
            or if ``architecture`` has no entry (or a falsy one) for that task.
    """
    architectures_for_task = SUPPORTED_MODEL_WITH_TASKS.get(task)

    if architectures_for_task:
        resolved = architectures_for_task.get(architecture)
        if resolved:
            return resolved

        # The task is known but this architecture is not: list what is.
        raise NotImplementedError(
            f"Architecture {architecture} is not supported for task {task}. "
            f"Only the following architectures are: {list(architectures_for_task.keys())}"
        )

    # Missing and empty task entries are both reported as an unsupported task.
    raise NotImplementedError(f"Task {task} is not supported yet.")