def infer_task_from_model_architecture()

in src/sagemaker_huggingface_inference_toolkit/transformers_utils.py [0:0]


def infer_task_from_model_architecture(model_config_path: str, architecture_index=0) -> str:
    """
    Infer the pipeline task from the `config.json` of a trained model.

    Detection is not guaranteed to succeed, e.g. some models implement multiple architectures or were
    trained on different tasks: https://huggingface.co/facebook/bart-large/blob/main/config.json.
    It should work for every model fine-tuned on Amazon SageMaker.
    It is always recommended to set the task explicitly through the env var `HF_TASK`.

    Args:
        model_config_path (str): Path to the model's `config.json` file.
        architecture_index (int): Index into the config's `architectures` list selecting the
            architecture to inspect. Defaults to 0 (the first entry).

    Returns:
        str: The inferred task. As a side effect it is also stored in `os.environ["HF_TASK"]`.

    Raises:
        ValueError: If the config has no string architecture at `architecture_index`, or if the
            architecture name does not end with any key of `ARCHITECTURES_2_TASK`.
    """
    with open(model_config_path, "r", encoding="utf-8") as config_file:
        config = json.load(config_file)

    # `architectures` may be missing, `null` or empty; treat all of these as "no architecture"
    # instead of crashing with AttributeError/TypeError/IndexError further down.
    architectures = config.get("architectures") or []
    try:
        architecture = architectures[architecture_index]
    except (IndexError, TypeError):
        architecture = None

    if not isinstance(architecture, str):
        raise ValueError(
            f"Task couldn't be inferred: `architectures` in {model_config_path} has no entry at index "
            f"{architecture_index}. Use env `HF_TASK` to define your task."
        )

    # More than one key can be a suffix of the same name (e.g. a model-specific "XForQuestionAnswering"
    # and a generic "ForQuestionAnswering"); pick the longest, i.e. most specific, suffix so the
    # result does not depend on the dict's iteration order.
    matching_suffixes = [suffix for suffix in ARCHITECTURES_2_TASK if architecture.endswith(suffix)]
    if not matching_suffixes:
        raise ValueError(
            f"Task couldn't be inferred from {architecture}. "
            f"Inference Toolkit can only infer tasks from architectures ending with {list(ARCHITECTURES_2_TASK.keys())}. "
            "Use env `HF_TASK` to define your task."
        )
    task = ARCHITECTURES_2_TASK[max(matching_suffixes, key=len)]

    # set env to work with
    os.environ["HF_TASK"] = task
    return task