# Multi-lingual ASR Transcription on IPUs using Whisper - Fine-tuning

This notebook demonstrates fine-tuning for multi-lingual speech transcription on the IPU using the [Whisper implementation in the 🤗 Transformers library](https://huggingface.co/spaces/openai/whisper) alongside [Optimum Graphcore](https://github.com/huggingface/optimum-graphcore). We will be using the Catalan subset of the [OpenSLR dataset](https://huggingface.co/datasets/openslr).

Whisper is a versatile speech recognition model that can transcribe speech as well as perform multi-lingual translation and recognition tasks.
It was trained on diverse datasets to give human-level speech recognition performance without the need for fine-tuning. 

[🤗 Optimum Graphcore](https://github.com/huggingface/optimum-graphcore) is the interface between the [🤗 Transformers library](https://huggingface.co/docs/transformers/index) and [Graphcore IPUs](https://www.graphcore.ai/products/ipu).
It provides a set of tools that enables model parallelization, loading on IPUs, training and fine-tuning on all the tasks already supported by Transformers. Optimum Graphcore is also compatible with the 🤗 Hub and every model available on it out of the box.

|  Domain | Tasks | Model | Datasets | Workflow |   Number of IPUs   | Execution time |
|---------|-------|-------|----------|----------|--------------|--------------|
| Automatic Speech Recognition | Transcription | whisper-small | OpenSLR (SLR69) | Fine-tuning | 4 or 16 | 33 mins total (18 mins training, or 6 mins on POD16) |

[![Join our Slack Community](https://img.shields.io/badge/Slack-Join%20Graphcore's%20Community-blue?style=flat-square&logo=slack)](https://www.graphcore.ai/join-community)

## Environment setup

The best way to run this demo is on Paperspace Gradient's cloud IPUs because everything is already set up for you.

To run the demo using other IPU hardware, you need to have the Poplar SDK enabled. Refer to the [Getting Started guide](https://docs.graphcore.ai/en/latest/getting-started.html#getting-started) for your system for details on how to enable the Poplar SDK. Also refer to the [Jupyter Quick Start guide](https://docs.graphcore.ai/projects/jupyter-notebook-quick-start/en/latest/index.html) for how to set up Jupyter to be able to run this notebook on a remote IPU machine.

## Dependencies and imports

In order to improve usability and support for future users, Graphcore would like to collect information about the
applications and code being run in this notebook. The following information will be anonymised before being sent to Graphcore:

- User progression through the notebook
- Notebook details: number of cells, code being run and the output of the cells
- Environment details

You can disable logging at any time by running `%unload_ext graphcore_cloud_tools.notebook_logging.gc_logger` from any cell.

Install the dependencies the notebook needs.

In [None]:
# Install a pinned release of optimum-graphcore plus the audio and metric dependencies
%pip install "optimum-graphcore==0.7.1" "soundfile" "librosa" "evaluate" "jiwer"
%pip install "graphcore-cloud-tools[logger] @ git+https://github.com/graphcore/graphcore-cloud-tools"
%load_ext graphcore_cloud_tools.notebook_logging.gc_logger

In [None]:
import os

n_ipu = int(os.getenv("NUM_AVAILABLE_IPU", 4))
executable_cache_dir = os.path.join(os.getenv("POPLAR_EXECUTABLE_CACHE_DIR", "/tmp/exe_cache/"), "whisper")

In [None]:
# Generic imports
from dataclasses import dataclass
from typing import Any, Dict, List, Union

import evaluate
import numpy as np
import torch
from datasets import load_dataset, Audio, Dataset, DatasetDict

# IPU-specific imports
from optimum.graphcore import (
    IPUConfig, 
    IPUSeq2SeqTrainer, 
    IPUSeq2SeqTrainingArguments, 
)
from optimum.graphcore.models.whisper import WhisperProcessorTorch

# HF-related imports
from transformers import WhisperForConditionalGeneration

## Load Dataset

The SLR69 subset of OpenSLR contains recordings of Catalan speakers together with their transcriptions. 🤗 Datasets enables us to easily download and prepare the training and evaluation splits.

We load the `train` split of the subset and hold out 20% of it for evaluation. The dataset is loaded with `token=False`, so no 🤗 Hub authentication token is sent with the request.

In [None]:
# Load the SLR69 subset and hold out 20% of it for evaluation
raw_dataset = load_dataset("openslr", "SLR69", split="train", token=False)
split_dataset = raw_dataset.train_test_split(test_size=0.2, seed=0)
dataset = DatasetDict()
dataset["train"] = split_dataset["train"]
dataset["eval"] = split_dataset["test"]
print(dataset)
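
To make the effect of `test_size=0.2` and `seed=0` concrete, here is a minimal pure-Python sketch of a seeded hold-out split. It is not how 🤗 Datasets implements `train_test_split`; the helper name `split_indices` and the choice to round the test size up are assumptions made for illustration only.

```python
import math
import random


def split_indices(n_rows, test_size=0.2, seed=0):
    """Return (train_indices, test_indices) for a seeded random split."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)  # a fixed seed gives a fixed order
    n_test = math.ceil(n_rows * test_size)  # rounding up is an assumption of this sketch
    return indices[n_test:], indices[:n_test]


train_idx, test_idx = split_indices(10, test_size=0.2, seed=0)
print(len(train_idx), len(test_idx))
```

Because the shuffle is seeded, rerunning the split yields the same train and evaluation rows, which keeps the WER figures comparable between runs.
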

The columns of interest are:
* `audio`: the raw audio samples
* `sentence`: the corresponding ground truth transcription. 

We drop the `path` column.

In [None]:
dataset = dataset.remove_columns(["path"])

Since Whisper was pre-trained on audio sampled at 16 kHz, we must ensure the audio samples are resampled to the same rate.

In [None]:
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

## Prepare Dataset

We prepare the datasets by extracting features from the raw audio inputs and adding labels, which are simply the tokenized transcriptions.

The feature extraction is provided by 🤗 Transformers `WhisperFeatureExtractor`. To decode generated tokens into text after running the model, we will similarly require a tokenizer, `WhisperTokenizer`. Both of these are wrapped by an instance of `WhisperProcessor`.

In [None]:
MODEL_NAME = "openai/whisper-small"
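LANGUAGE = "spanish"  # language token for the prompt; note the dataset is Catalan, which Whisper also supports as "catalan"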
TASK = "transcribe"
MAX_LENGTH = 224

processor = WhisperProcessorTorch.from_pretrained(MODEL_NAME, language=LANGUAGE, task=TASK)
processor.tokenizer.pad_token = processor.tokenizer.eos_token
processor.tokenizer.max_length = MAX_LENGTH
processor.tokenizer.set_prefix_tokens(language=LANGUAGE, task=TASK)

In [None]:
def prepare_dataset(batch, processor):
    inputs = processor.feature_extractor(
        raw_speech=batch["audio"]["array"],
        sampling_rate=batch["audio"]["sampling_rate"],
    )
    batch["input_features"] = inputs.input_features[0].astype(np.float16)

    transcription = batch["sentence"]
    batch["labels"] = processor.tokenizer(text=transcription).input_ids
    return batch

columns_to_remove = dataset.column_names["train"]
dataset = dataset.map(
    lambda elem: prepare_dataset(elem, processor),
    remove_columns=columns_to_remove,
    num_proc=1,
)

train_dataset = dataset["train"]
eval_dataset = dataset["eval"]
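
As a rough guide to what `prepare_dataset` produces, the sketch below works out the shape and size of one `input_features` array. It assumes the default `WhisperFeatureExtractor` settings (30-second window, hop length of 160 samples, 80 mel bins); the helper names are hypothetical and nothing here calls the library.

```python
SAMPLING_RATE = 16_000  # Hz, the rate the audio column is cast to
CHUNK_SECONDS = 30      # assumed default: clips are padded or truncated to 30 s
HOP_LENGTH = 160        # assumed default: samples between spectrogram frames
N_MELS = 80             # assumed default: mel bins per frame


def feature_shape(sampling_rate=SAMPLING_RATE, chunk_seconds=CHUNK_SECONDS,
                  hop_length=HOP_LENGTH, n_mels=N_MELS):
    """Shape of one log-mel feature array as (mel bins, frames)."""
    n_samples = sampling_rate * chunk_seconds
    return (n_mels, n_samples // hop_length)


def feature_bytes(shape, bytes_per_value):
    """Memory needed to store one feature array."""
    n_mels, n_frames = shape
    return n_mels * n_frames * bytes_per_value


shape = feature_shape()
print(shape)                    # (80, 3000)
print(feature_bytes(shape, 2))  # float16: 480000 bytes
print(feature_bytes(shape, 4))  # float32: 960000 bytes
```

Under these assumptions every example has the same feature shape regardless of the clip length, and storing it as `float16` halves the memory needed compared to `float32`.
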

Lastly, we pre-process the labels by padding them with values that will be ignored during fine-tuning. This padding is to ensure tensors of static shape are provided to the model. We do this on the fly via the data collator below.

In [None]:
@dataclass
class DataCollatorSpeechSeq2SeqWithLabelProcessing:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        batch = {}
        batch["input_features"] = torch.tensor([feature["input_features"] for feature in features])
        
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt", padding="longest", pad_to_multiple_of=MAX_LENGTH)
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch
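
The padding logic in the collator can be illustrated without PyTorch. The sketch below pads lists of token IDs with `-100` up to the next multiple of a given length, mirroring the combined effect of `pad_to_multiple_of` and `masked_fill` above. The function name `pad_labels` is hypothetical, and right-side padding is assumed.

```python
IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss


def pad_labels(label_lists, multiple, ignore_index=IGNORE_INDEX):
    """Right-pad every label list to a shared length that is a multiple of `multiple`."""
    longest = max(len(labels) for labels in label_lists)
    target = -(-longest // multiple) * multiple  # round up to the next multiple
    return [labels + [ignore_index] * (target - len(labels)) for labels in label_lists]


padded = pad_labels([[1, 2, 3], [4]], multiple=4)
print(padded)  # [[1, 2, 3, -100], [4, -100, -100, -100]]
```

With `multiple=MAX_LENGTH`, every batch whose longest label sequence fits in `MAX_LENGTH` tokens is padded to exactly that length, which gives the static shape the IPU needs. By the same arithmetic, a longer sequence would push the batch to the next multiple.
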

## Define metrics

The performance of our fine-tuned model will be evaluated using word error rate (WER).

In [None]:
metric = evaluate.load("wer")


def compute_metrics(pred, tokenizer):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    pred_ids = np.where(pred_ids != -100, pred_ids, tokenizer.pad_token_id)
    label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    normalized_pred_str = [tokenizer._normalize(pred).strip() for pred in pred_str]
    normalized_label_str = [tokenizer._normalize(label).strip() for label in label_str]

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    normalized_wer = 100 * metric.compute(predictions=normalized_pred_str, references=normalized_label_str)

    return {"wer": wer, "normalized_wer": normalized_wer}
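
For intuition about the two numbers reported above, here is a minimal sketch of word error rate as word-level edit distance divided by the number of reference words. It is a simplified stand-in, not the implementation behind `evaluate.load("wer")`, and the `simple_normalize` helper is a hypothetical, much cruder substitute for the tokenizer's own normalization.

```python
def edit_distance(ref_words, hyp_words):
    """Minimum number of word substitutions, insertions and deletions."""
    previous = list(range(len(hyp_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        current = [i]
        for j, hyp_word in enumerate(hyp_words, start=1):
            substitution = previous[j - 1] + (ref_word != hyp_word)
            current.append(min(previous[j] + 1, current[j - 1] + 1, substitution))
        previous = current
    return previous[-1]


def word_error_rate(references, predictions):
    """Total word errors over total reference words, across all sentence pairs."""
    errors = sum(edit_distance(r.split(), p.split()) for r, p in zip(references, predictions))
    total_words = sum(len(r.split()) for r in references)
    return errors / total_words


def simple_normalize(text):
    """Hypothetical normalizer: lowercase and drop punctuation."""
    return "".join(ch for ch in text.lower() if ch.isalnum() or ch.isspace()).strip()


print(word_error_rate(["bon dia a tothom"], ["bon dia tothom"]))  # 0.25
print(word_error_rate(["Bon dia."], ["bon dia"]))                 # 1.0
print(word_error_rate([simple_normalize("Bon dia.")], [simple_normalize("bon dia")]))  # 0.0
```

The last two lines show why a normalized WER is worth reporting alongside the raw one: differences in casing and punctuation alone can count every word as an error.
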

## Load pre-trained model

In [None]:
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME)

In [None]:
model.config.max_length = MAX_LENGTH
model.generation_config.max_length = MAX_LENGTH

Ensure language-appropriate tokens, if any, are set for generation. We set them on both the `config` and the `generation_config` to ensure they are used correctly during generation.

In [None]:
model.config.forced_decoder_ids = processor.tokenizer.get_decoder_prompt_ids(
    language=LANGUAGE, task=TASK
)
model.config.suppress_tokens = []
model.generation_config.forced_decoder_ids = processor.tokenizer.get_decoder_prompt_ids(
    language=LANGUAGE, task=TASK
)
model.generation_config.suppress_tokens = []

## Fine-tuning Whisper on the IPU

The model can be directly fine-tuned on the IPU using the `IPUSeq2SeqTrainer` class. 

The `IPUConfig` object specifies how the model will be pipelined across the IPUs. 

For fine-tuning, we place the encoder on two IPUs, and the decoder on two IPUs.

For inference, the encoder is placed on one IPU, and the decoder on a different IPU.

In [None]:
replication_factor = n_ipu // 4
ipu_config = IPUConfig.from_dict(
    {
        "optimizer_state_offchip": True,
        "recompute_checkpoint_every_layer": True,
        "enable_half_partials": True,
        "executable_cache_dir": executable_cache_dir,
        "gradient_accumulation_steps": 16,
        "replication_factor": replication_factor,
        "layers_per_ipu": [5, 7, 5, 7],
        "matmul_proportion": [0.2, 0.2, 0.6, 0.6],
        "projection_serialization_factor": 5,
        "inference_replication_factor": 1,
        "inference_layers_per_ipu": [12, 12],
        "inference_parallelize_kwargs": {
            "use_cache": True,
            "use_encoder_output_buffer": True,
            "on_device_generation_steps": 16,
        }
    }
)
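
The `layers_per_ipu` values can be read as a simple layer-to-IPU assignment. The sketch below assumes that layers are counted encoder first and then decoder, and that whisper-small has 12 encoder and 12 decoder layers; `assign_layers` is a hypothetical helper for illustration, not part of Optimum Graphcore.

```python
N_ENCODER_LAYERS = 12  # whisper-small
N_DECODER_LAYERS = 12  # whisper-small


def assign_layers(layers_per_ipu):
    """Return the IPU index for each layer, filling IPUs in order."""
    placement = []
    for ipu, n_layers in enumerate(layers_per_ipu):
        placement.extend([ipu] * n_layers)
    return placement


def ipus_used(placement):
    """Split a placement into the IPUs used by the encoder and by the decoder."""
    encoder = sorted(set(placement[:N_ENCODER_LAYERS]))
    decoder = sorted(set(placement[N_ENCODER_LAYERS:]))
    return encoder, decoder


print(ipus_used(assign_layers([5, 7, 5, 7])))  # ([0, 1], [2, 3])
print(ipus_used(assign_layers([12, 12])))      # ([0], [1])
```

Under these assumptions, the training layout spreads each half of the model over two IPUs, while the inference layout uses one IPU for the encoder and one for the decoder, matching the description above.
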

Next, we specify the arguments controlling the training process.

In [None]:
total_steps = 1000 // replication_factor
training_args = IPUSeq2SeqTrainingArguments(
    output_dir="./whisper-small-ipu-checkpoints",
    do_train=True,
    do_eval=True,
    predict_with_generate=True,
    learning_rate=1e-5 * replication_factor,
    warmup_steps=total_steps // 4,
    evaluation_strategy="steps",
    eval_steps=total_steps,
    max_steps=total_steps,
    save_strategy="steps",
    save_steps=total_steps,
    logging_steps=25,
    dataloader_num_workers=16,
    dataloader_drop_last=True,
)
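
The way the schedule scales with the number of IPUs can be checked with plain arithmetic. The sketch below reproduces the expressions used in the two cells above for 4 and 16 IPUs; the function name `training_schedule` and its keyword arguments are hypothetical.

```python
def training_schedule(n_ipu, ipus_per_replica=4, base_steps=1000, base_lr=1e-5):
    """Mirror the notebook's scaling of steps and learning rate with replicas."""
    replication_factor = n_ipu // ipus_per_replica
    total_steps = base_steps // replication_factor
    return {
        "replication_factor": replication_factor,
        "max_steps": total_steps,
        "warmup_steps": total_steps // 4,
        "learning_rate": base_lr * replication_factor,
    }


for n_ipu in (4, 16):
    print(n_ipu, training_schedule(n_ipu))
```

With 16 IPUs there are four replicas, so each optimizer step consumes four times as many samples. Dividing the step count by the replication factor keeps the total number of samples seen roughly constant, and the learning rate is scaled up linearly to go with the larger effective batch.
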

Then, we just need to pass all of this together with our datasets to the `IPUSeq2SeqTrainer` class:

In [None]:
trainer = IPUSeq2SeqTrainer(
    model=model,
    ipu_config=ipu_config,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=DataCollatorSpeechSeq2SeqWithLabelProcessing(processor),
    compute_metrics=lambda x: compute_metrics(x, processor.tokenizer),
    tokenizer=processor.feature_extractor,
)

To gauge the improvement in WER, we run an evaluation step before fine-tuning.

In [None]:
trainer.evaluate()

All that remains is to fine-tune the model! The fine-tuning process should take between 6 and 18 minutes, depending on how many replicas are used, and achieve a final WER of around 10%.

In [None]:
trainer.train()

## Conclusion

In this notebook, we demonstrated how to fine-tune Whisper for multi-lingual speech recognition and transcription on the IPU. We used a single replica on a total of four IPUs. Reducing the fine-tuning time requires more than one replica, and therefore more IPUs. On Paperspace, you can use either an IPU-POD16 or a Bow-POD16, both with 16 IPUs. Please contact Graphcore if you need assistance running on larger platforms.

For all available notebooks, check [IPU-powered Jupyter Notebooks](https://www.graphcore.ai/ipu-jupyter-notebooks) to see how IPUs perform on other tasks.

Have a question? Please contact us on our [Graphcore community channel](https://www.graphcore.ai/join-community).
