in utils/ryzenai/notification_service.py [0:0]
def infer_model_id(model):
    """Resolve a model name to the matching entry in the known model lists.

    Names are compared after replacing every "." and "-" with "_", so an
    input whose punctuation differs from a known name still resolves to
    that name's original spelling.

    Args:
        model: Model name to resolve. Names containing "timm" are looked up
            among the keys of ``tu.PYTORCH_TIMM_MODEL["default-timm-config"]``;
            otherwise names containing "amd" are looked up in the
            concatenated ``tu.RYZEN_PREQUANTIZED_MODEL_*`` collections.

    Returns:
        The first candidate whose normalised form equals the normalised
        input, or ``model`` unchanged when it contains neither marker or
        when no candidate matches.
    """

    def _canonical(name):
        # "." and "-" are both folded to "_" before comparing names.
        return name.replace(".", "_").replace("-", "_")

    wanted = _canonical(model)

    is_timm = "timm" in model
    if not is_timm and "amd" not in model:
        # Neither family marker is present: nothing to look up.
        return model

    if is_timm:
        candidates = list(tu.PYTORCH_TIMM_MODEL["default-timm-config"].keys())
    else:
        candidates = (
            tu.RYZEN_PREQUANTIZED_MODEL_IMAGE_CLASSIFICATION
            + list(tu.RYZEN_PREQUANTIZED_MODEL_OBJECT_DETECTION.values())
            + tu.RYZEN_PREQUANTIZED_MODEL_IMAGE_SEGMENTATION
            + tu.RYZEN_PREQUANTIZED_MODEL_IMAGE_TO_IMAGE
            + tu.RYZEN_PREQUANTIZED_MODEL_CUSTOM_TASKS
        )

    # First match wins; fall back to the caller's spelling when none match.
    return next(
        (known for known in candidates if _canonical(known) == wanted),
        model,
    )