in optimum/onnxruntime/modeling_diffusion.py [0:0]
def _get_task_ort_class(mapping, pipeline_class_name):
def _get_model_name(pipeline_class_name):
for ort_pipelines_mapping in SUPPORTED_ORT_PIPELINES_MAPPINGS:
for model_name, ort_pipeline_class in ort_pipelines_mapping.items():
if (
ort_pipeline_class.__name__ == pipeline_class_name
or ort_pipeline_class.auto_model_class.__name__ == pipeline_class_name
):
return model_name
model_name = _get_model_name(pipeline_class_name)
if model_name is not None:
task_class = mapping.get(model_name, None)
if task_class is not None:
return task_class
raise ValueError(f"ORTPipelineForTask can't find a pipeline linked to {pipeline_class_name} for {model_name}")