def _create_traceable_model()

in eland/ml/pytorch/transformers.py [0:0]


    def _create_traceable_model(self) -> _TransformerTraceableModel:
        """Build the traceable wrapper that matches this instance's task type.

        When the task type is ``"auto"``, the base model is loaded first and
        the task type is derived from its config through
        ``task_type_from_model_config``; the resolved value is written back to
        ``self._task_type`` before dispatching.

        Returns:
            The ``_Traceable*Model`` for the task, constructed from
            ``self._tokenizer`` and the model loaded for ``self._model_id``.

        Raises:
            TaskTypeError: If the task type is ``"auto"`` and no task type can
                be derived from the model config.
            TypeError: If the task type is not one of the names handled below.
        """

        def load_auto(auto_cls):
            # Every transformers Auto* class used here is loaded with the same
            # model id, access token and torchscript=True.
            return auto_cls.from_pretrained(
                self._model_id, token=self._access_token, torchscript=True
            )

        def load_wrapped(auto_cls):
            # Load through the Auto* class, then hand the result to
            # _DistilBertWrapper.try_wrapping.
            return _DistilBertWrapper.try_wrapping(load_auto(auto_cls))

        if self._task_type == "auto":
            base_model = load_auto(transformers.AutoModel)
            detected = task_type_from_model_config(base_model.config)
            if detected is None:
                raise TaskTypeError(
                    f"Unable to automatically determine task type for model {self._model_id}, please supply task type: {SUPPORTED_TASK_TYPES_NAMES}"
                )
            self._task_type = detected

        task = self._task_type

        # Masked-LM backed tasks.
        if task == "text_expansion":
            wrapped = load_wrapped(transformers.AutoModelForMaskedLM)
            return _TraceableTextExpansionModel(self._tokenizer, wrapped)
        if task == "fill_mask":
            wrapped = load_wrapped(transformers.AutoModelForMaskedLM)
            return _TraceableFillMaskModel(self._tokenizer, wrapped)

        # Token-classification backed task.
        if task == "ner":
            wrapped = load_wrapped(transformers.AutoModelForTokenClassification)
            return _TraceableNerModel(self._tokenizer, wrapped)

        # Sequence-classification backed tasks.
        if task == "text_classification":
            wrapped = load_wrapped(transformers.AutoModelForSequenceClassification)
            return _TraceableTextClassificationModel(self._tokenizer, wrapped)
        if task == "zero_shot_classification":
            wrapped = load_wrapped(transformers.AutoModelForSequenceClassification)
            return _TraceableZeroShotClassificationModel(self._tokenizer, wrapped)
        if task == "text_similarity":
            wrapped = load_wrapped(transformers.AutoModelForSequenceClassification)
            return _TraceableTextSimilarityModel(self._tokenizer, wrapped)

        # Tasks whose models come from project-specific wrapper loaders.
        if task == "text_embedding":
            # Try the DPR encoder wrapper first; when it yields a falsy result,
            # fall back to the sentence-transformer wrapper.
            encoder = _DPREncoderWrapper.from_pretrained(
                self._model_id, token=self._access_token
            ) or _SentenceTransformerWrapperModule.from_pretrained(
                self._model_id, self._tokenizer, token=self._access_token
            )
            return _TraceableTextEmbeddingModel(self._tokenizer, encoder)
        if task == "question_answering":
            qa_model = _QuestionAnsweringWrapperModule.from_pretrained(
                self._model_id, token=self._access_token
            )
            return _TraceableQuestionAnsweringModel(self._tokenizer, qa_model)

        # Pass-through uses the base model directly, without try_wrapping.
        if task == "pass_through":
            return _TraceablePassThroughModel(
                self._tokenizer, load_auto(transformers.AutoModel)
            )

        raise TypeError(
            f"Task {self._task_type} is not supported, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
        )