# optimum/graphcore/pipelines/__init__.py
# Copyright 2018 The HuggingFace Inc. team.
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Union
import poptorch
import torch
import transformers.pipelines
from peft import PeftModel
from transformers import (
AudioClassificationPipeline,
AutoFeatureExtractor,
AutoModelForAudioClassification,
AutoModelForCausalLM,
AutoModelForCTC,
AutoModelForImageClassification,
AutoModelForMaskedLM,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForSpeechSeq2Seq,
AutoModelForTokenClassification,
AutoTokenizer,
ImageClassificationPipeline,
Pipeline,
PreTrainedTokenizer,
QuestionAnsweringPipeline,
TextClassificationPipeline,
TextGenerationPipeline,
WhisperForConditionalGeneration,
)
from transformers.feature_extraction_utils import PreTrainedFeatureExtractor
from transformers.modeling_utils import PreTrainedModel
from transformers.pipelines import get_task
from transformers.utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
from optimum.graphcore.generation.utils import MODELS_SUPPORTING_KV_CACHE, IPUGenerationMixin
from optimum.graphcore.ipu_configuration import IncompatibleIPUConfigError, IPUConfig
from optimum.graphcore.modeling_utils import to_pipelined
from .automatic_speech_recognition import IPUAutomaticSpeechRecognitionPipeline
from .fill_mask import IPUFillMaskPipeline
from .text2text_generation import IPUSummarizationPipeline, IPUText2TextGenerationPipeline, IPUTranslationPipeline
from .token_classification import IPUTokenClassificationPipeline
from .zero_shot_classification import IPUZeroShotClassificationPipeline
logger = logging.get_logger(__name__)
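# Map task aliases onto their canonical pipeline task names.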
TASK_ALIASES = {
"sentiment-analysis": "text-classification",
"ner": "token-classification",
}
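# Registry of tasks supported on IPU. For each task: the pipeline implementation ("impl"), the auto model
# classes to try ("class"), the default checkpoint/revision and IPU config ("default"), and the input
# modality ("type").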
SUPPORTED_TASKS = {
"audio-classification": {
"impl": AudioClassificationPipeline,
"class": (AutoModelForAudioClassification,),
"default": {
"model": ("superb/hubert-base-superb-ks", "d7e0efe"),
"ipu_config": "Graphcore/hubert-base-ipu",
},
"type": "audio",
},
"automatic-speech-recognition": {
"impl": IPUAutomaticSpeechRecognitionPipeline,
"class": (AutoModelForCTC, AutoModelForSpeechSeq2Seq),
"default": {
"model": ("facebook/wav2vec2-base-960h", "55bb623"),
"ipu_config": "Graphcore/wav2vec2-ctc-base-ipu",
},
"type": "multimodal",
},
"fill-mask": {
"impl": IPUFillMaskPipeline,
"class": (AutoModelForMaskedLM,),
"default": {
"model": ("distilroberta-base", "ec58a5b"),
"ipu_config": "Graphcore/distilroberta-base-ipu",
"max_length": 128,
},
"type": "text",
},
"image-classification": {
"impl": ImageClassificationPipeline,
"class": (AutoModelForImageClassification,),
"default": {
"model": ("google/vit-base-patch16-224", "5dca96d"),
"ipu_config": "Graphcore/vit-base-ipu",
},
"type": "image",
},
"question-answering": {
"impl": QuestionAnsweringPipeline,
"class": (AutoModelForQuestionAnswering,),
"default": {
"model": ("distilbert-base-cased-distilled-squad", "626af31"),
"ipu_config": "Graphcore/distilbert-base-ipu",
},
"type": "text",
},
"text-classification": {
"impl": TextClassificationPipeline,
"class": (AutoModelForSequenceClassification,),
"default": {
"model": ("distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"),
"ipu_config": "Graphcore/distilbert-base-ipu",
"max_length": 128,
},
"type": "text",
},
"token-classification": {
"impl": IPUTokenClassificationPipeline,
"class": (AutoModelForTokenClassification,),
"default": {
"model": ("dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"),
"ipu_config": "Graphcore/bert-large-ipu",
"max_length": 128,
},
"type": "text",
},
"text-generation": {
"impl": TextGenerationPipeline,
"class": (AutoModelForCausalLM,),
"default": {
"model": ("gpt2", "e7da7f2"),
"ipu_config": IPUConfig(),
"max_length": 50,
},
"type": "text",
},
"summarization": {
"impl": IPUSummarizationPipeline,
"class": (AutoModelForSeq2SeqLM,),
"default": {
"model": ("ainize/bart-base-cnn", "b90bc9a"),
"ipu_config": IPUConfig(ipus_per_replica=2),
"max_input_length": 50,
"max_length": 20,
"truncation": "only_first",
},
"type": "text",
},
# This task is a special case as it's parametrized by SRC, TGT languages.
"translation": {
"impl": IPUTranslationPipeline,
"class": (AutoModelForSeq2SeqLM,),
"default": {
"model": ("t5-small", "9507060"),
"ipu_config": IPUConfig(ipus_per_replica=2),
"max_length": 50,
"max_input_length": 45,
"truncation": "only_first",
},
"type": "text",
},
"text2text-generation": {
"impl": IPUText2TextGenerationPipeline,
"class": (AutoModelForSeq2SeqLM,),
"default": {
"model": ("t5-small", "9507060"),
"ipu_config": IPUConfig(ipus_per_replica=2),
"max_length": 50,
"max_input_length": 50,
"truncation": "only_first",
},
"type": "text",
},
"zero-shot-classification": {
"impl": IPUZeroShotClassificationPipeline,
"class": (AutoModelForSequenceClassification,),
"default": {
"model": ("roberta-large-mnli", "130fb28"),
"ipu_config": "Graphcore/roberta-large-ipu",
"max_length": 128,
},
"type": "text",
},
}
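# Tasks whose pipelines run autoregressive generation; their models manage their own Poplar executor(s).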
SUPPORTED_GENERATION_TASKS = {
"summarization",
"text-generation",
"text2text-generation",
"translation",
}
SUPPORTED_SEQ2SEQ_GENERATION_TASKS = {"summarization", "text2text-generation", "translation"}
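# Work out from each task's modality whether a tokenizer and/or a feature extractor should be loaded.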
NO_FEATURE_EXTRACTOR_TASKS = set()
NO_TOKENIZER_TASKS = set()
for task, values in SUPPORTED_TASKS.items():
if values["type"] == "text":
NO_FEATURE_EXTRACTOR_TASKS.add(task)
elif values["type"] in {"audio", "image"}:
NO_TOKENIZER_TASKS.add(task)
elif values["type"] != "multimodal":
raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}")
def list_tasks() -> List[str]:
"""Lists the supported tasks and their aliases"""
return sorted([*{*SUPPORTED_TASKS, *TASK_ALIASES}])
def get_poplar_executor(
    task: str,
    model: PreTrainedModel,
    ipu_config: Optional[Union[IPUConfig, str, dict]] = None,
    fp16: bool = True,
    for_generation: bool = False,
    **parallelize_kwargs,
) -> PreTrainedModel:
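    """
    Converts a `PreTrainedModel` into its pipelined IPU counterpart, parallelizes it according to `ipu_config`
    (optionally casting to float16) and, unless the model is used for generation (in which case it manages its
    own Poplar executor internally), wraps it in `poptorch.inferenceModel`.
    """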
ipu_config_arg = ipu_config
if isinstance(ipu_config, str):
ipu_config = IPUConfig.from_pretrained(ipu_config)
elif isinstance(ipu_config, dict):
ipu_config = IPUConfig.from_dict(ipu_config)
elif not isinstance(ipu_config, IPUConfig):
raise ValueError("ipu_config must be an IPUConfig, string, or a dictionary.")
# So that IPUConfig returns inference versions of any parameters
# that are different in training and inference
ipu_config.eval()
    parallelize_kwargs = {**ipu_config.inference_parallelize_kwargs, **parallelize_kwargs}
ipu_config.inference_device_iterations = 1
if not parallelize_kwargs.get("use_cond_encoder", False):
ipu_config.inference_replication_factor = 1
if not fp16:
ipu_config.enable_half_partials = False
try:
model = to_pipelined(model, ipu_config, force=False)
if model.config.is_encoder_decoder and isinstance(model, IPUGenerationMixin):
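            # Default to the on-device KV cache for encoder-decoder models that support it, unless the caller
            # has already set `use_cache`.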
if "use_cache" not in parallelize_kwargs and model.__class__ in MODELS_SUPPORTING_KV_CACHE:
parallelize_kwargs["use_cache"] = True
model.parallelize(for_generation=for_generation, **parallelize_kwargs)
else:
model.parallelize(**parallelize_kwargs)
except Exception as error:
new_message = (
"The model and ipu_config seem to be incompatible,"
" please try a different IPU config or customize it for the model."
f" The config provided is '{ipu_config_arg}'\n"
f"{error}"
)
raise IncompatibleIPUConfigError(new_message) from error
if fp16:
model.half()
opts = ipu_config.to_options(for_inference=True)
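    # Use PopTorch's sharded execution strategy: model shards run on their IPUs one after another rather than
    # as a pipeline.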
opts.setExecutionStrategy(poptorch.ShardedExecution())
    # Text generation models manage their own Poplar executor internally, so don't wrap the model in that case
if not for_generation:
model = poptorch.inferenceModel(model.eval(), opts)
return model
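# Replacement for `Pipeline.check_model_type` (patched onto `Pipeline` below) that also recognizes models wrapped
# in a `poptorch.PoplarExecutor` or pipelined through `IPUGenerationMixin`.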
def check_model_type(self, supported_models: Union[List[str], dict]):
"""
Check if the model class is supported by the pipeline.
Args:
supported_models (`List[str]` or `dict`):
The list of models supported by the pipeline, or a dictionary with model class values.
"""
if not isinstance(supported_models, list): # Create from a model mapping
supported_models_names = []
for config, model in supported_models.items():
# Mapping can now contain tuples of models for the same configuration.
if isinstance(model, tuple):
supported_models_names.extend([_model.__name__ for _model in model])
else:
supported_models_names.append(model.__name__)
supported_models = supported_models_names
if isinstance(self.model, poptorch.PoplarExecutor):
model_class_name = self.model._user_model.__class__.__bases__[0].__name__
elif isinstance(self.model, IPUGenerationMixin):
model_class_name = self.model.__class__.__bases__[0].__name__
else:
model_class_name = self.model.__class__.__name__
if model_class_name not in supported_models:
logger.error(
f"The model '{model_class_name}' is not supported for {self.task}. Supported models are"
f" {supported_models}."
)
def pipeline(
    task: Optional[str] = None,
    model: Optional[Any] = None,
    ipu_config: Optional[Union[IPUConfig, str, dict]] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
revision: Optional[str] = None,
use_auth_token: Optional[Union[str, bool]] = None,
pipeline_class: Optional[Any] = None,
fp16: bool = True,
parallelize_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Pipeline:
"""Utility factory method to build a [ Pipeline ] for IPU models.
Arguments:
task : The task, see docs for ``transformers.pipeline`` for supported options.
model : A pre-trained model, see docs for ``transformers.pipeline`` for supported options.
ipu_config : An IPU config, can either be the path to a model from the HuggingFace Hub
which defines a ``ipu_config.json`` or a dictionary with the same options.
tokenizer : The tokenizer, see docs for ``transformers.pipeline`` for supported options.
feature_extractor : The feature extractor, see docs for ``transformers.pipeline`` for supported options.
revision : Revision of the model.
use_auth_token : An authorization token to use for calls to the Hub.
pipeline_class : Override the `Pipeline` class defined by the task.
fp16 : If `True`, uses float16.
**kwargs: Additional keyword arguments that are passed to the ``transformers.pipeline`` function
Returns:
The pipeline object for the specified task.
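    Example (illustrative; when no model is given, the task's default checkpoint and IPU config are used):

        >>> from optimum.graphcore.pipelines import pipeline
        >>> classifier = pipeline("sentiment-analysis")
        >>> classifier("IPUs make transformer inference fast.")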
"""
if task is None and model is None:
raise RuntimeError(
"Impossible to instantiate a pipeline without either a task or a model "
"being specified. "
"Please provide a task class or a model"
)
if task is None and model is not None:
if not isinstance(model, str):
raise RuntimeError(
"Inferring the task automatically requires to check the Hub with a model_id defined as a `str`."
f"{model} is not a valid model_id."
)
task = get_task(model, use_auth_token)
if task in TASK_ALIASES:
task = TASK_ALIASES[task]
targeted_task = "translation" if task.startswith("translation") else task
if targeted_task not in SUPPORTED_TASKS:
raise ValueError(f"Task {targeted_task} is not supported. Supported tasks are {list(SUPPORTED_TASKS.keys())}")
    # Tasks in NO_TOKENIZER_TASKS / NO_FEATURE_EXTRACTOR_TASKS never require a tokenizer / feature extractor.
    # The model repo might still ship one, but the files could be missing from the Hub; instead of failing on
    # such repos, we simply don't load it.
load_tokenizer = targeted_task not in NO_TOKENIZER_TASKS
load_feature_extractor = targeted_task not in NO_FEATURE_EXTRACTOR_TASKS
if model is None:
model_id, revision = SUPPORTED_TASKS[targeted_task]["default"]["model"]
logger.warning(
f"No model was supplied, defaulted to {model_id} and revision"
f" {revision} ({HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{model_id}).\n"
"Using a pipeline without specifying a model name and revision in production is not recommended."
)
model = SUPPORTED_TASKS[targeted_task]["class"][0].from_pretrained(model_id, revision=revision)
elif isinstance(model, str):
model_id = model
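        # Try each auto model class registered for the task until one accepts the checkpoint.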
for cl in SUPPORTED_TASKS[targeted_task]["class"]:
try:
model = cl.from_pretrained(model_id, revision=revision)
break
except ValueError:
continue
elif isinstance(model, PeftModel):
raise TypeError(
"Instead of providing `model` as an instance of `PeftModel`, please call `merge_and_unload()` if LoRA "
"or equivalent to obtain the original `PreTrainedModel` back with adapter weights merged in."
)
elif isinstance(model, PreTrainedModel):
if tokenizer is None and load_tokenizer:
raise ValueError("If you pass a model as a PreTrainedModel, you must pass a tokenizer as well")
if feature_extractor is None and load_feature_extractor:
raise ValueError("If you pass a model as a PreTrainedModel, you must pass a feature extractor as well")
for_generation = targeted_task in SUPPORTED_GENERATION_TASKS
if isinstance(model, PreTrainedModel):
if ipu_config is None:
ipu_config = SUPPORTED_TASKS[targeted_task]["default"]["ipu_config"]
parallelize_kwargs = parallelize_kwargs or {}
        # Automatic speech recognition is a bit of an edge case: it splits into CTC models (no generation) and
        # seq2seq models (generation). This check will do for now.
for_generation |= isinstance(model, WhisperForConditionalGeneration)
model = get_poplar_executor(
targeted_task, model, ipu_config=ipu_config, fp16=fp16, for_generation=for_generation, **parallelize_kwargs
)
elif isinstance(model, poptorch._poplar_executor.PoplarExecutor):
if tokenizer is None and load_tokenizer:
raise ValueError(
"If you pass a model as a poptorch._poplar_executor.PoplarExecutor, you must pass a tokenizer as well"
)
if feature_extractor is None and load_feature_extractor:
raise ValueError(
"If you pass a model as a poptorch._poplar_executor.PoplarExecutor, you must pass a feature extractor as well"
)
else:
        raise ValueError(
            f"Model {model} is not supported. Please provide a valid model either as a string, a PreTrainedModel"
            " or a poptorch._poplar_executor.PoplarExecutor. If you don't provide a model, a default model will"
            " be used."
        )
# Upstream pipeline creation does not easily support loading these when an actual model
# is provided, so we load them here.
if tokenizer is None and load_tokenizer:
tokenizer = AutoTokenizer.from_pretrained(model_id)
if feature_extractor is None and load_feature_extractor:
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
# Override Pipeline methods
Pipeline.check_model_type = check_model_type
if pipeline_class is None:
pipeline_class = SUPPORTED_TASKS[targeted_task]["impl"]
    # Override the pipeline's _forward to handle batch size changes and fp16 inputs on IPU
old_forward = pipeline_class._forward
def new_forward(self, model_inputs, *args, **kwargs):
if isinstance(self.model, poptorch.PoplarExecutor) and not for_generation:
# For non-text generation models, support batch size changes.
poplar_executor = self.model
if poplar_executor._executable_inputs:
for arg in poplar_executor._executable_inputs.args:
if isinstance(arg, torch.Tensor):
compiled_bs = arg.shape[0]
break
for input in model_inputs.values():
if isinstance(input, torch.Tensor):
input_bs = input.shape[0]
break
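                # The executor was compiled for a different batch size: destroy it so it is recompiled with the
                # new shape on the next call.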
if compiled_bs != input_bs:
poplar_executor.destroy()
if isinstance(self.model, poptorch.PoplarExecutor) or isinstance(self.model, IPUGenerationMixin):
if fp16:
# Support fp16
for key, input in model_inputs.items():
if isinstance(input, torch.Tensor) and input.dtype == torch.float32:
model_inputs[key] = input.half()
return old_forward(self, model_inputs, *args, **kwargs)
pipeline_class._forward = new_forward
    # Implement the pipeline's __del__ to clean up the Poplar executor(s)
def _del(self):
# For text generation models, deallocate the internal poplar executors
if hasattr(self.model, "poptorch_decoder"):
self.model.poptorch_decoder.destroy()
if hasattr(self.model, "poptorch_encoder"):
self.model.poptorch_encoder.destroy()
pipeline_class.__del__ = _del
# Auto padding for some tasks
if "max_length" in SUPPORTED_TASKS[targeted_task]["default"]:
default_max_length = SUPPORTED_TASKS[targeted_task]["default"]["max_length"]
if not for_generation:
kwargs["padding"] = kwargs.get("padding", "max_length")
if kwargs.get("max_length") is None:
logger.warning(
f"No padding arguments specified, so padding to {default_max_length} by default. "
f"Inputs longer than {default_max_length} will be truncated."
" To change this behaviour, pass the `padding='max_length'` and"
"`max_length=<your desired input length>` arguments to the pipeline function."
)
kwargs["max_length"] = kwargs.get("max_length", default_max_length)
if targeted_task in SUPPORTED_SEQ2SEQ_GENERATION_TASKS:
default_max_input_length = SUPPORTED_TASKS[targeted_task]["default"]["max_input_length"]
kwargs["max_input_length"] = kwargs.get("max_input_length", default_max_input_length)
default_truncation = SUPPORTED_TASKS[targeted_task]["default"]["truncation"]
kwargs["truncation"] = kwargs.get("truncation", default_truncation)
# question-answering already has its own default padding length `max_seq_len` defined, so we just enable padding to max length.
if targeted_task in {"question-answering"}:
kwargs["padding"] = kwargs.get("padding", "max_length")
logger.warning(
"No padding arguments specified, so padding to 384 by default. Inputs longer than 384 will be truncated."
)
# Set pad_token for models that do not have pad_token
if model.config.model_type in {"gpt2"}:
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
return transformers.pipelines.pipeline(
task,
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
use_auth_token=use_auth_token,
pipeline_class=pipeline_class,
**kwargs,
)