in eland/ml/pytorch/transformers.py [0:0]
def _create_traceable_model(self) -> _TransformerTraceableModel:
    """Load the Hugging Face model for ``self._task_type`` and wrap it for tracing.

    If ``self._task_type`` is ``"auto"``, the task type is first inferred from
    the model's configuration with ``task_type_from_model_config`` and the
    result is stored back on ``self._task_type``.

    Returns:
        A traceable-model wrapper pairing ``self._tokenizer`` with the loaded
        model for the resolved task type.

    Raises:
        TaskTypeError: if ``self._task_type`` is ``"auto"`` and no task type
            can be inferred from the model configuration.
        TypeError: if the (resolved) task type is not supported.
    """
    if self._task_type == "auto":
        # Only the configuration is needed to infer the task type, so load it
        # on its own instead of instantiating the full model (all weights),
        # which the task-specific branch below would load a second time anyway.
        # ``torchscript=True`` is passed as a config-attribute override so the
        # config matches the one a ``torchscript=True`` model load produces.
        config = transformers.AutoConfig.from_pretrained(
            self._model_id, token=self._access_token, torchscript=True
        )
        maybe_task_type = task_type_from_model_config(config)
        if maybe_task_type is None:
            raise TaskTypeError(
                f"Unable to automatically determine task type for model {self._model_id}, please supply task type: {SUPPORTED_TASK_TYPES_NAMES}"
            )
        self._task_type = maybe_task_type

    # Tasks that share one recipe: load a stock ``AutoModelFor*`` head with
    # ``torchscript=True``, pass it through ``_DistilBertWrapper.try_wrapping``,
    # then hand it to the task's traceable wrapper.
    # Maps task type -> (auto-model class, traceable wrapper class).
    auto_head_tasks = {
        "text_expansion": (
            transformers.AutoModelForMaskedLM,
            _TraceableTextExpansionModel,
        ),
        "fill_mask": (
            transformers.AutoModelForMaskedLM,
            _TraceableFillMaskModel,
        ),
        "ner": (
            transformers.AutoModelForTokenClassification,
            _TraceableNerModel,
        ),
        "text_classification": (
            transformers.AutoModelForSequenceClassification,
            _TraceableTextClassificationModel,
        ),
        "zero_shot_classification": (
            transformers.AutoModelForSequenceClassification,
            _TraceableZeroShotClassificationModel,
        ),
        "text_similarity": (
            transformers.AutoModelForSequenceClassification,
            _TraceableTextSimilarityModel,
        ),
    }

    if self._task_type in auto_head_tasks:
        auto_cls, traceable_cls = auto_head_tasks[self._task_type]
        model = auto_cls.from_pretrained(
            self._model_id, token=self._access_token, torchscript=True
        )
        model = _DistilBertWrapper.try_wrapping(model)
        return traceable_cls(self._tokenizer, model)

    if self._task_type == "text_embedding":
        # Try the DPR encoder wrapper first; when it yields nothing (falsy),
        # fall back to the sentence-transformers wrapper.
        model = _DPREncoderWrapper.from_pretrained(
            self._model_id, token=self._access_token
        )
        if not model:
            model = _SentenceTransformerWrapperModule.from_pretrained(
                self._model_id, self._tokenizer, token=self._access_token
            )
        return _TraceableTextEmbeddingModel(self._tokenizer, model)

    if self._task_type == "question_answering":
        model = _QuestionAnsweringWrapperModule.from_pretrained(
            self._model_id, token=self._access_token
        )
        return _TraceableQuestionAnsweringModel(self._tokenizer, model)

    if self._task_type == "pass_through":
        # Base model with no task head and no DistilBERT wrapping.
        model = transformers.AutoModel.from_pretrained(
            self._model_id, token=self._access_token, torchscript=True
        )
        return _TraceablePassThroughModel(self._tokenizer, model)

    raise TypeError(
        f"Task {self._task_type} is not supported, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
    )