def infer_model_id()

in utils/ryzenai/notification_service.py [0:0]


def infer_model_id(model: str) -> str:
    """Map a sanitized model identifier back to its canonical registry name.

    ``model`` and every candidate name are compared after folding both ``.``
    and ``-`` to ``_``. The first candidate whose folded form equals the folded
    ``model`` is returned.

    Candidate names are taken from the ``tu`` registries:

    * if ``"timm"`` occurs in ``model`` (case-sensitive substring): the keys of
      ``tu.PYTORCH_TIMM_MODEL["default-timm-config"]``;
    * otherwise, if ``"amd"`` occurs in ``model``: the Ryzen pre-quantized
      image-classification list, the values of the object-detection mapping,
      and the image-segmentation, image-to-image and custom-task lists.

    Args:
        model: The (possibly sanitized) model identifier.

    Returns:
        The matching registry name. ``model`` is returned unchanged when it
        contains neither ``"timm"`` nor ``"amd"``, or when no candidate
        matches.
    """
    # A single translation table maps "." and "-" to "_". This is equivalent
    # to chaining .replace(".", "_").replace("-", "_").
    to_underscore = str.maketrans(".-", "__")
    wanted = model.translate(to_underscore)

    if "timm" in model:
        timm_configs = tu.PYTORCH_TIMM_MODEL["default-timm-config"]
        candidates = list(timm_configs.keys())
    elif "amd" in model:
        # Concatenated with `+`, so every registry except object detection
        # (whose values are taken) is expected to already be a list.
        candidates = (
            tu.RYZEN_PREQUANTIZED_MODEL_IMAGE_CLASSIFICATION
            + list(tu.RYZEN_PREQUANTIZED_MODEL_OBJECT_DETECTION.values())
            + tu.RYZEN_PREQUANTIZED_MODEL_IMAGE_SEGMENTATION
            + tu.RYZEN_PREQUANTIZED_MODEL_IMAGE_TO_IMAGE
            + tu.RYZEN_PREQUANTIZED_MODEL_CUSTOM_TASKS
        )
    else:
        # Not a registry-backed family; there is nothing to resolve against.
        return model

    # First match wins; fall back to the input identifier when none matches.
    return next(
        (name for name in candidates if name.translate(to_underscore) == wanted),
        model,
    )