def task_type_from_model_config()

in eland/ml/pytorch/transformers.py


from typing import Optional, Set

from transformers import PretrainedConfig

# ARCHITECTURE_TO_TASK_TYPE and ZERO_SHOT_LABELS are module-level constants
# defined elsewhere in eland/ml/pytorch/transformers.py.


def task_type_from_model_config(model_config: PretrainedConfig) -> Optional[str]:
    # Without declared architectures, fall back to the model name: the
    # "sentence-transformers/" prefix marks a text embedding model.
    if model_config.architectures is None:
        if model_config.name_or_path.startswith("sentence-transformers/"):
            return "text_embedding"
        return None
    # Collect every task type whose architecture substring appears in one of
    # the model's declared architectures.
    potential_task_types: Set[str] = set()
    for architecture in model_config.architectures:
        for substr, task_type in ARCHITECTURE_TO_TASK_TYPE.items():
            if substr in architecture:
                for t in task_type:
                    potential_task_types.add(t)
    # No architecture matched: apply the same name-based fallback.
    if len(potential_task_types) == 0:
        if model_config.name_or_path.startswith("sentence-transformers/"):
            return "text_embedding"
        return None
    # A classification model with exactly one label is mapped to text
    # similarity rather than text classification.
    if (
        "text_classification" in potential_task_types
        and model_config.id2label
        and len(model_config.id2label) == 1
    ):
        return "text_similarity"
    # More than one candidate task type: disambiguate using the label
    # mapping and the model name.
    if len(potential_task_types) > 1:
        if "zero_shot_classification" in potential_task_types:
            # Zero-shot classification requires every label (lower-cased)
            # to be present in ZERO_SHOT_LABELS.
            if model_config.label2id:
                labels = set([x.lower() for x in model_config.label2id.keys()])
                if len(labels.difference(ZERO_SHOT_LABELS)) == 0:
                    return "zero_shot_classification"
            return "text_classification"
        if "text_embedding" in potential_task_types:
            # Embedding vs. fill-mask is decided by the model name prefix.
            if model_config.name_or_path.startswith("sentence-transformers/"):
                return "text_embedding"
            return "fill_mask"
    # Exactly one candidate remains; return it.
    return potential_task_types.pop()
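
For illustration, a minimal sketch of how this helper might be called, assuming the transformers package is installed; the hub model id below is only an example:

from transformers import AutoConfig

# Load the Hugging Face config for an example hub model and infer its task
# type. Any model id could be used here; for a sentence-transformers model
# the name-prefix fallback is expected to yield "text_embedding".
config = AutoConfig.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
print(task_type_from_model_config(config))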