# src/sagemaker_huggingface_inference_toolkit/transformers_utils.py
def infer_task_from_model_architecture(model_config_path: str, architecture_index: int = 0) -> str:
    """
    Infer the pipeline task from the `config.json` of a trained model.

    Detection is best-effort: some models implement multiple architectures or were
    trained on different tasks (e.g. https://huggingface.co/facebook/bart-large/blob/main/config.json).
    It should work for every model fine-tuned on Amazon SageMaker. It is always
    recommended to set the task explicitly through the env var `TASK`.

    Args:
        model_config_path: Path to the model's `config.json` file.
        architecture_index: Index into the config's `architectures` list to inspect.

    Returns:
        The inferred task name. As a side effect the env var `HF_TASK` is set to it.

    Raises:
        ValueError: If the config lists no architecture or the architecture does not
            end with any suffix known to `ARCHITECTURES_2_TASK`.
    """
    with open(model_config_path, "r") as config_file:
        config = json.load(config_file)
    architecture = config.get("architectures", [None])[architecture_index]

    task = None
    # `architecture` is None when config.json has no "architectures" key; guard so we
    # raise the descriptive ValueError below instead of an AttributeError on .endswith.
    if architecture is not None:
        # No break: if several suffixes match, the last one wins (original behavior).
        for arch_suffix, arch_task in ARCHITECTURES_2_TASK.items():
            if architecture.endswith(arch_suffix):
                task = arch_task
    if task is None:
        raise ValueError(
            f"Task couldn't be inferenced from {architecture}."
            f"Inference Toolkit can only inference tasks from architectures ending with {list(ARCHITECTURES_2_TASK.keys())}."
            "Use env `HF_TASK` to define your task."
        )
    # set env to work with
    os.environ["HF_TASK"] = task
    return task