# optimum/graphcore/pipelines/automatic_speech_recognition.py
# Copyright 2021 The HuggingFace Team. All rights reserved.
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
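"""An IPU-friendly automatic speech recognition pipeline.

Graphcore IPUs execute statically compiled graphs, so the model expects every
batch it sees to have the same shape. The pipeline below therefore pads the
final, possibly smaller batch of a run up to `batch_size` before `forward`,
and strips the padded entries from the results afterwards.
"""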
import os

from transformers import AutomaticSpeechRecognitionPipeline
from transformers.pipelines.base import (
    DataLoader,
    PipelineChunkIterator,
    PipelineIterator,
    PipelinePackIterator,
    no_collate_fn,
    pad_collate_fn,
)
from transformers.utils import logging


logger = logging.get_logger(__name__)


class IPUAutomaticSpeechRecognitionPipeline(AutomaticSpeechRecognitionPipeline):
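    """`AutomaticSpeechRecognitionPipeline` adapted for the IPU.

    Overrides `get_iterator` to pad the last batch up to a full `batch_size`
    (the IPU's compiled graph requires fixed input shapes) and `_forward` to
    drop the outputs produced for those padding entries.
    """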
    def get_iterator(
        self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
    ):
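        # Mirrors `ChunkPipeline.get_iterator` from `transformers`, with one
        # IPU-specific change: the dataloader output is routed through
        # `batch_padding` (below) so that the last batch is padded to a full
        # `batch_size`.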
if "TOKENIZERS_PARALLELISM" not in os.environ:
logger.info("Disabling tokenizer parallelism, we're using DataLoader multithreading already")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if num_workers > 1:
logger.warning(
"For ChunkPipeline using num_workers>0 is likely to result in errors since everything is iterable,"
" setting `num_workers=1` to guarantee correctness."
)
num_workers = 1
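        # `no_collate_fn` only supports `batch_size == 1`; for larger batches,
        # `pad_collate_fn` pads variable-length items within a batch using the
        # tokenizer's and feature extractor's padding values.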
        dataset = PipelineChunkIterator(inputs, self.preprocess, preprocess_params)
        collate_fn = no_collate_fn if batch_size == 1 else pad_collate_fn(self.tokenizer, self.feature_extractor)
        dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_fn)
        # Change: if the last batch contains fewer than `batch_size` elements, pad it
        # up to `batch_size` so the compiled IPU graph always sees full batches.
        def batch_padding(items, batch_size):
            is_last = items["is_last"]
            # With `no_collate_fn` (`batch_size == 1`), `is_last` is a bare bool;
            # otherwise the collate function has gathered it into a list.
            actual_batch_size = 1 if isinstance(is_last, bool) else len(is_last)
            if actual_batch_size >= batch_size:
                return items

            n_to_pad = batch_size - actual_batch_size
            # Mark padded entries with `None` so `_forward` can find and drop them.
            is_last = is_last + [None] * n_to_pad
            # Pad the input features by duplicating genuine feature values rather than
            # e.g. zeros. This makes it significantly more likely beam search will terminate.
            input_features = items["input_features"]
            new_input_features = input_features.repeat(
                batch_size // actual_batch_size + 1, *([1] * (input_features.ndim - 1))
            )
            new_input_features = new_input_features[:batch_size]
            padded_items = {"is_last": is_last, "input_features": new_input_features}

            stride = items.get("stride", None)
            if stride is not None:
                # Replicate the last genuine stride for the padded entries.
                stride = stride + [stride[-1]] * n_to_pad
                padded_items["stride"] = stride
            return padded_items
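        # For illustration (hypothetical shapes, e.g. Whisper's (80, 3000) log-mel
        # features): with `batch_size=4` and a final batch of 3 items, features of
        # shape (3, 80, 3000) are repeated to (6, 80, 3000) and truncated to
        # (4, 80, 3000), while `is_last` becomes [False, False, True, None].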
        if self.type == "seq2seq_whisper":
            dataloader = PipelineIterator(dataloader, batch_padding, {"batch_size": batch_size})
        model_iterator = PipelinePackIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
        final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
        return final_iterator

    def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
        # Only Whisper needs the padding-aware forward; defer everything else to the parent class.
        if self.type != "seq2seq_whisper":
            return super()._forward(model_inputs, return_timestamps=return_timestamps, generate_kwargs=generate_kwargs)
        if generate_kwargs is None:
            generate_kwargs = {}
        if return_timestamps:
            generate_kwargs["return_timestamps"] = return_timestamps
        is_last = model_inputs.pop("is_last")
        if "input_features" in model_inputs:
            inputs = model_inputs.pop("input_features")
        elif "input_values" in model_inputs:
            inputs = model_inputs.pop("input_values")
        else:
            raise ValueError(
                "Seq2Seq speech recognition model requires either an "
                f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
            )

        attention_mask = model_inputs.pop("attention_mask", None)
        tokens = self.model.generate(inputs=inputs, attention_mask=attention_mask, **generate_kwargs)
out = {"tokens": tokens}
stride = model_inputs.pop("stride", None)
if stride is not None:
out["stride"] = stride
extra = model_inputs
maybe_padded_ret = {"is_last": is_last, **out, **extra}
        # Remove inputs and outputs associated with padded inputs: `batch_padding`
        # (in `get_iterator`) marked each padded entry with `is_last=None`.
        if not isinstance(is_last, list):
            is_last = [is_last]
        first_padding_idx = tokens.shape[0]
        for idx, last in enumerate(is_last):
            if last is None:
                first_padding_idx = idx
                break
        if first_padding_idx == tokens.shape[0]:
            # No padding found, so return everything as-is.
            return maybe_padded_ret

        padded_keys = ["is_last", "tokens"]
        if stride is not None:
            padded_keys.append("stride")
        for padded_key in padded_keys:
            maybe_padded_ret[padded_key] = maybe_padded_ret[padded_key][:first_padding_idx]
        return maybe_padded_ret
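
# A minimal usage sketch, not part of the original module. The checkpoint name is
# illustrative and, on real hardware, the model would additionally need to be
# prepared for the IPU (e.g. parallelized and compiled), which is assumed here:
#
#     from transformers import WhisperForConditionalGeneration, WhisperProcessor
#
#     processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
#     model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
#     asr = IPUAutomaticSpeechRecognitionPipeline(
#         model=model,
#         tokenizer=processor.tokenizer,
#         feature_extractor=processor.feature_extractor,
#     )
#     print(asr("sample.wav", batch_size=4))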