in eland/ml/pytorch/transformers.py [0:0]
def task_type_from_model_config(model_config: PretrainedConfig) -> Optional[str]:
    """Guess the NLP task type of a model from its configuration.

    Every entry of ``model_config.architectures`` is matched against the keys
    of the module-level ``ARCHITECTURE_TO_TASK_TYPE`` mapping using a substring
    test, and the task types of all matching keys are pooled as candidates.
    The candidates are then narrowed down as follows:

    * No architectures, or no candidate at all: ``"text_embedding"`` if
      ``model_config.name_or_path`` starts with ``"sentence-transformers/"``,
      otherwise ``None``.
    * ``"text_classification"`` is a candidate and ``id2label`` holds exactly
      one label: ``"text_similarity"``.
    * Several candidates including ``"zero_shot_classification"``:
      ``"zero_shot_classification"`` when ``label2id`` is non-empty and every
      lower-cased label is contained in ``ZERO_SHOT_LABELS``; otherwise
      ``"text_classification"``.
    * Several candidates including ``"text_embedding"``: ``"text_embedding"``
      for ``"sentence-transformers/"`` models, otherwise ``"fill_mask"``.
    * Anything else: the single remaining candidate, or the alphabetically
      smallest one when several remain.

    :param model_config: configuration object; the attributes read are
        ``architectures``, ``name_or_path``, ``id2label`` and ``label2id``.
    :return: the inferred task type, or ``None`` if it cannot be determined.
    """

    def is_sentence_transformers_model() -> bool:
        # Evaluated lazily so ``name_or_path`` is only touched on the code
        # paths that actually need this check.
        return model_config.name_or_path.startswith("sentence-transformers/")

    if model_config.architectures is None:
        return "text_embedding" if is_sentence_transformers_model() else None

    # Pool the task types of every mapping key that occurs as a substring of
    # one of the declared architecture names.
    potential_task_types: Set[str] = {
        task_type
        for architecture in model_config.architectures
        for substr, task_types in ARCHITECTURE_TO_TASK_TYPE.items()
        if substr in architecture
        for task_type in task_types
    }

    if not potential_task_types:
        return "text_embedding" if is_sentence_transformers_model() else None

    if (
        "text_classification" in potential_task_types
        and model_config.id2label
        and len(model_config.id2label) == 1
    ):
        return "text_similarity"

    if len(potential_task_types) > 1:
        if "zero_shot_classification" in potential_task_types:
            if model_config.label2id:
                labels = {label.lower() for label in model_config.label2id}
                # Zero-shot only if the model defines no label outside of
                # ZERO_SHOT_LABELS.
                if labels.issubset(ZERO_SHOT_LABELS):
                    return "zero_shot_classification"
            return "text_classification"
        if "text_embedding" in potential_task_types:
            if is_sentence_transformers_model():
                return "text_embedding"
            return "fill_mask"

    # ``set.pop()`` would hand back an arbitrary element, so with several
    # remaining candidates the answer could change between interpreter runs.
    # ``min`` is identical for a single candidate and reproducible otherwise.
    return min(potential_task_types)