in eland/ml/pytorch/transformers.py [0:0]
def _compatible_inputs(self) -> Tuple[Tensor, ...]:
    """Return the prepared inputs as a positional tuple chosen by tokenizer type.

    The inputs come from ``self._prepare_inputs()``. If they carry no
    ``"token_type_ids"`` entry, a zero-filled ``torch.long`` tensor is added.
    The returned tuple depends on the type of ``self._tokenizer``:

    * BART, MPNet, RoBERTa or XLM-RoBERTa tokenizer:
      ``(input_ids, attention_mask)``
    * DeBERTa-v2 tokenizer:
      ``(input_ids, attention_mask, token_type_ids)``
    * any other tokenizer:
      ``(input_ids, attention_mask, token_type_ids, position_ids)``, where
      ``position_ids`` is a freshly created ``torch.arange`` tensor.

    Note that the mapping returned by ``self._prepare_inputs()`` is modified
    in place (``token_type_ids`` and, in the last case, ``position_ids``).

    Returns:
        Tuple of tensors in the order described above.
    """
    encoded = self._prepare_inputs()

    # Add params when not provided by the tokenizer (e.g. DistilBERT), so the
    # inputs conform to the BERT interface.
    if "token_type_ids" not in encoded:
        # Size of dimension 1 of input_ids (presumably the sequence length).
        length = encoded["input_ids"].size(1)
        encoded["token_type_ids"] = torch.zeros(length, dtype=torch.long)

    tokenizer = self._tokenizer

    # Tokenizers for which only the ids and the attention mask are returned.
    ids_and_mask_only = (
        transformers.BartTokenizer,
        transformers.MPNetTokenizer,
        transformers.RobertaTokenizer,
        transformers.XLMRobertaTokenizer,
    )
    if isinstance(tokenizer, ids_and_mask_only):
        return (encoded["input_ids"], encoded["attention_mask"])

    # DeBERTa-v2 gets token type ids as well, but no position ids.
    if isinstance(tokenizer, transformers.DebertaV2Tokenizer):
        keys = ("input_ids", "attention_mask", "token_type_ids")
        return tuple(encoded[key] for key in keys)

    # Every other tokenizer: also supply explicit position ids 0..length-1.
    encoded["position_ids"] = torch.arange(
        encoded["input_ids"].size(1), dtype=torch.long
    )
    keys = ("input_ids", "attention_mask", "token_type_ids", "position_ids")
    return tuple(encoded[key] for key in keys)