optimum/graphcore/pipelines/text2text_generation.py (69 lines of code) (raw):
# Copyright 2021 The HuggingFace Team. All rights reserved.
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline
from transformers.pipelines.text2text_generation import TruncationStrategy
class IPUText2TextGenerationPipeline(Text2TextGenerationPipeline):
    """Text2text generation pipeline adapted for IPU execution.

    Differs from the upstream ``Text2TextGenerationPipeline`` in that inputs
    are always padded with ``padding="max_length"`` so tensor shapes are
    static (presumably an IPU compilation requirement — the padded length is
    controlled by the extra ``max_input_length`` preprocess parameter).
    """

    def _sanitize_parameters(
        self,
        return_tensors=None,
        return_text=None,
        return_type=None,
        clean_up_tokenization_spaces=None,
        truncation=None,
        stop_sequence=None,
        max_input_length=None,
        **generate_kwargs,
    ):
        """Split call kwargs into preprocess/forward/postprocess parameter dicts.

        Extends the parent implementation with ``max_input_length``, which is
        forwarded to ``_parse_and_tokenize`` via the preprocess params.
        """
        preprocess_params, forward_params, postprocess_params = super()._sanitize_parameters(
            return_tensors,
            return_text,
            return_type,
            clean_up_tokenization_spaces,
            truncation,
            stop_sequence,
            **generate_kwargs,
        )
        if max_input_length is not None:
            preprocess_params["max_input_length"] = max_input_length
        return preprocess_params, forward_params, postprocess_params

    def _parse_and_tokenize(self, *args, truncation, **kwargs):
        """Tokenize ``args[0]`` (a ``str`` or a ``list`` of ``str``), prepending
        the model's configured prefix to every input.

        Padding is always ``"max_length"`` (static shapes), with the target
        length taken from ``kwargs["max_input_length"]`` when provided.

        Raises:
            ValueError: if a batch (list) input is given but the tokenizer has
                no ``pad_token_id``, or if ``args[0]`` is neither ``str`` nor
                ``list``.
        """
        prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
        if isinstance(args[0], list):
            if self.tokenizer.pad_token_id is None:
                raise ValueError("Please make sure that the tokenizer has a pad_token_id when using a batch input")
            args = ([prefix + arg for arg in args[0]],)
        elif isinstance(args[0], str):
            args = (prefix + args[0],)
        else:
            raise ValueError(
                f"`args[0]`: {args[0]} has the wrong format. It should be either of type `str` or type `list`"
            )
        inputs = self.tokenizer(
            *args,
            # Always pad to max_length so tensor shapes stay static on IPU;
            # upstream's dynamic True/False padding is deliberately not used.
            padding="max_length",
            max_length=kwargs.get("max_input_length"),
            truncation=truncation,
            return_tensors=self.framework,
        )
        # Tokenizers may emit token_type_ids, which is an invalid generate kwarg.
        if "token_type_ids" in inputs:
            del inputs["token_type_ids"]
        return inputs
class IPUSummarizationPipeline(SummarizationPipeline, IPUText2TextGenerationPipeline):
    """Summarization pipeline that runs on IPUs.

    Pure mixin composition: stock ``SummarizationPipeline`` behaviour combined
    with the fixed-length tokenization of ``IPUText2TextGenerationPipeline``;
    no additional logic is required.
    """
class IPUTranslationPipeline(TranslationPipeline, IPUText2TextGenerationPipeline):
    """Translation pipeline that runs on IPUs, padding inputs to a fixed length."""

    def preprocess(
        self, *args, truncation=TruncationStrategy.DO_NOT_TRUNCATE, src_lang=None, tgt_lang=None, max_input_length=None
    ):
        """Tokenize translation inputs with static (``"max_length"``) padding.

        Prefers the tokenizer's dedicated ``_build_translation_inputs`` hook
        when present; otherwise falls back to the generic IPU tokenization
        path (resolved through the MRO — presumably
        ``IPUText2TextGenerationPipeline._parse_and_tokenize``).
        """
        build_translation_inputs = getattr(self.tokenizer, "_build_translation_inputs", None)
        if not build_translation_inputs:
            # Tokenizer has no translation-specific builder: use the generic path.
            return super()._parse_and_tokenize(*args, truncation=truncation, max_input_length=max_input_length)
        return build_translation_inputs(
            *args,
            return_tensors=self.framework,
            max_length=max_input_length,
            padding="max_length",
            truncation=truncation,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
        )