notebooks/wav2vec2/wav2vec2-fine-tuning-checkpoint.ipynb (803 lines of code) (raw):

{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "eb00796b", "metadata": {}, "source": [ "# Automatic Speech Recognition (ASR) on IPU using wav2vec - Fine-tuning\n", "\n", "This notebook will demonstrate how to fine-tune a pre-trained wav2vec 2.0 model with PyTorch on Graphcore IPUs. We will use a `wav2vec2-base` model and fine-tune it for a CTC downstream task using the [LibriSpeech](https://huggingface.co/datasets/librispeech_asr) dataset.\n", "\n", "We will show how to use a wav2vec 2.0 model written in PyTorch from the [🤗 Transformers library](https://huggingface.co/docs/transformers/index) and parallelize it easily using the [🤗 Optimum Graphcore library](https://github.com/huggingface/optimum-graphcore).\n", "\n", "🤗 provides convenient access to pre-trained transformer models. The partnership between 🤗 and Graphcore allows us to run these models on the IPU.\n", "\n", "🤗 models ported to the IPU can be found on the [Graphcore Hugging Face organisation page](https://huggingface.co/Graphcore)." ] }, { "attachments": {}, "cell_type": "markdown", "id": "eaf275aa", "metadata": {}, "source": [ "| Domain | Tasks | Model | Datasets | Workflow | Number of IPUs | Execution time |\n", "|---------|-------|-------|----------|----------|--------------|--------------|\n", "| Audio processing | Automatic speech recognition | wav2vec 2.0 | LibriSpeech (librispeech_asr) | Training | 4 or 16 | 30min |" ] }, { "attachments": {}, "cell_type": "markdown", "id": "883c444e", "metadata": {}, "source": [ "[![Join our Slack Community](https://img.shields.io/badge/Slack-Join%20Graphcore's%20Community-blue?style=flat-square&logo=slack)](https://www.graphcore.ai/join-community)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "3fcc332a", "metadata": {}, "source": [ "## Background\n", "\n", "Automatic speech recognition (ASR), the task of transcribing audio automatically, has historically required large amounts of labelled data. 
Additionally, these systems have predominantly used fixed feature extraction methods that do not learn from the raw signal, for example the short-time Fourier transform or Mel-frequency cepstral coefficients. Research conducted by Facebook AI (now Meta AI) introduced a [framework for self-supervised learning for speech representations](https://arxiv.org/abs/2006.11477). In other words, the researchers demonstrated a pre-training phase and architecture which can learn feature representations, and their relationships, by leveraging large amounts of unlabelled, raw audio data. \n", "\n", "There are two phases to training: pre-training on unlabelled data, and fine-tuning on a downstream task. In the original paper the model is fine-tuned for ASR using a connectionist temporal classification (CTC) loss. The modules shared between pre-training and fine-tuning are what you’d expect to see in a CTC system: a feature extractor and an encoder. But, unlike many earlier models, the feature extractor is a convolutional neural network, which makes it trainable. Following that, there is a BERT-style encoder in which a large convolutional block is applied before the first layer, rather than using sinusoidal positional encoding. " ] }, { "attachments": {}, "cell_type": "markdown", "id": "dc70faef", "metadata": {}, "source": [ "## Environment setup\n", "\n", "The best way to run this demo is on Paperspace Gradient's cloud IPUs because everything is already set up for you.\n", "\n", "[![Run on Gradient](https://assets.paperspace.io/img/gradient-badge.svg)](https://ipu.dev/3CGkbMq)\n", "\n", "To run the demo using other IPU hardware, you need to have the Poplar SDK enabled. Refer to the [Getting Started guide](https://docs.graphcore.ai/en/latest/getting-started.html#getting-started) for your system for details on how to enable the Poplar SDK. 
Also refer to the [Jupyter Quick Start guide](https://docs.graphcore.ai/projects/jupyter-notebook-quick-start/en/latest/index.html) for how to set up Jupyter to be able to run this notebook on a remote IPU machine." ] }, { "attachments": {}, "cell_type": "markdown", "id": "ed181f50", "metadata": {}, "source": [ "## Dependencies and configuration" ] }, { "attachments": {}, "cell_type": "markdown", "id": "23ba46a1", "metadata": {}, "source": [ "In order to improve usability and support for future users, Graphcore would like to collect information about the\n", "applications and code being run in this notebook. The following information will be anonymised before being sent to Graphcore:\n", "\n", "- User progression through the notebook\n", "- Notebook details: number of cells, code being run and the output of the cells\n", "- Environment details\n", "\n", "You can disable logging at any time by running `%unload_ext graphcore_cloud_tools.notebook_logging.gc_logger` from any cell." ] }, { "attachments": {}, "cell_type": "markdown", "id": "4625673f", "metadata": {}, "source": [ "Install the dependencies for this notebook." 
] }, { "cell_type": "code", "execution_count": null, "id": "4ae7f3c3", "metadata": {}, "outputs": [], "source": [ "%%bash\n", "apt update\n", "apt-get install libsndfile1 -y" ] }, { "cell_type": "code", "execution_count": null, "id": "add16b1f", "metadata": {}, "outputs": [], "source": [ "%pip install -r requirements.txt\n", "%load_ext graphcore_cloud_tools.notebook_logging.gc_logger" ] }, { "attachments": {}, "cell_type": "markdown", "id": "673c83f7", "metadata": {}, "source": [ "Next import the utilities that will be used later in the notebook: " ] }, { "cell_type": "code", "execution_count": null, "id": "d3dae5e6", "metadata": {}, "outputs": [], "source": [ "import functools\n", "import json\n", "import logging\n", "import os\n", "import re\n", "import sys\n", "import warnings\n", "from dataclasses import dataclass, field\n", "from typing import Dict, List, Optional, Union\n", "\n", "import datasets\n", "import numpy as np\n", "import torch\n", "from datasets import DatasetDict, load_dataset, load_metric\n", "from pathlib import Path\n", "import transformers\n", "from optimum.graphcore import IPUConfig, IPUTrainer\n", "from optimum.graphcore import IPUTrainingArguments\n", "from transformers import (\n", " AutoConfig,\n", " AutoFeatureExtractor,\n", " AutoModelForCTC,\n", " AutoProcessor,\n", " AutoTokenizer,\n", " HfArgumentParser,\n", " Wav2Vec2Processor,\n", " set_seed,\n", ")\n", "from transformers.trainer_utils import get_last_checkpoint\n", "from transformers.utils import check_min_version\n", "from transformers.utils.versions import require_version" ] }, { "cell_type": "code", "execution_count": null, "id": "4a47cf47", "metadata": {}, "outputs": [], "source": [ "set_seed(0)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "c8313f3a", "metadata": {}, "source": [ "Values for machine size and cache directories can be configured through environment variables or directly in the notebook:" ] }, { "cell_type": "code", "execution_count": null, "id": 
"42e5ab85", "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "n_ipu = int(os.getenv(\"NUM_AVAILABLE_IPU\", 4))\n", "executable_cache_dir = os.getenv(\"POPLAR_EXECUTABLE_CACHE_DIR\", \"/tmp/exe_cache/\") + \"/wav2vec2_fine_tuning\"\n", "checkpoint_directory = Path(os.getenv(\"CHECKPOINT_DIR\", \"/tmp\")) / \"demo\"" ] }, { "attachments": {}, "cell_type": "markdown", "id": "e3baf3ee", "metadata": {}, "source": [ "## Preparing the LibriSpeech dataset\n", "\n", "The [🤗 Datasets](https://huggingface.co/docs/datasets/index) library can be used to conveniently load the LibriSpeech dataset, and the library provides easy-to-use tools to process the data.\n", "\n", "First we create a `DatasetDict` dictionary to handle our data, and then load the LibriSpeech splits for training and validation. For this notebook we will use `train.100` which is 100 hours of clean training data. Section C of the appendix in the paper on the [framework for self-supervised learning for speech representations](https://arxiv.org/abs/2006.11477) suggests that fine-tuning a `Base` model can yield a 6.1% word error rate (WER) without an additional language model." ] }, { "cell_type": "code", "execution_count": null, "id": "72874ff7", "metadata": {}, "outputs": [], "source": [ "raw_datasets = DatasetDict()\n", "raw_datasets[\"train\"] = load_dataset(\"librispeech_asr\", \"clean\", split=\"train.100\")\n", "raw_datasets[\"eval\"] = load_dataset(\"librispeech_asr\", \"clean\", split=\"validation\")" ] }, { "cell_type": "code", "execution_count": null, "id": "7abca91e", "metadata": {}, "outputs": [], "source": [ "raw_datasets" ] }, { "attachments": {}, "cell_type": "markdown", "id": "539ac40f", "metadata": {}, "source": [ "### Text normalisation\n", "\n", "Using the package `map` function, any special characters are removed from the transcription. The resultant transcript is then lower-cased. These two processes mean that the model will not have to learn punctuation and capitalisation. 
Although the model may have the ability to learn capitalisation and punctuation, it will be easier if this is not required.\n", "\n", "There are other situations where text normalisation may be used, such as converting digits into their text counterparts. This is not performed in this notebook as the LibriSpeech transcripts already spell numbers out as words." ] }, { "cell_type": "code", "execution_count": null, "id": "28065436", "metadata": {}, "outputs": [], "source": [ "# Build a regex character class so that each listed character is matched (and removed) individually\n", "chars_to_ignore_regex = \"[\" + re.escape(\"\".join([\",\", \"?\", \".\", \"!\", \"-\", \";\", \":\", \"\\\"\", \"“\", \"%\", \"‘\", \"”\", \"�\"])) + \"]\"\n", "text_column_name = \"text\"\n", "\n", "\n", "def remove_special_characters(batch):\n", "    if chars_to_ignore_regex is not None:\n", "        batch[\"target_text\"] = re.sub(chars_to_ignore_regex, \"\", batch[text_column_name]).lower() + \" \"\n", "    else:\n", "        batch[\"target_text\"] = batch[text_column_name].lower() + \" \"\n", "    return batch\n", "\n", "\n", "raw_datasets = raw_datasets.map(\n", "    remove_special_characters,\n", "    remove_columns=[text_column_name],\n", "    desc=\"remove special characters from datasets\",\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "04c580de", "metadata": {}, "source": [ "### Create vocabulary and tokenizer\n", "\n", "We now create a vocabulary from the dataset. This will find all the unique characters from all the normalised text in the datasets." 
] }, { "cell_type": "code", "execution_count": null, "id": "342e7e36", "metadata": {}, "outputs": [], "source": [ "def create_vocabulary_from_data(\n", "    datasets: DatasetDict,\n", "    word_delimiter_token=None,\n", "    unk_token=None,\n", "    pad_token=None,\n", "):\n", "    # Given training and test labels create vocabulary\n", "    def extract_all_chars(batch):\n", "        all_text = \" \".join(batch[\"target_text\"])\n", "        vocab = list(set(all_text))\n", "        return {\"vocab\": [vocab], \"all_text\": [all_text]}\n", "\n", "    vocabs = datasets.map(\n", "        extract_all_chars,\n", "        batched=True,\n", "        batch_size=-1,\n", "        keep_in_memory=True,\n", "        remove_columns=datasets[\"train\"].column_names,\n", "    )\n", "\n", "    # take union of all unique characters in each dataset\n", "    vocab_set = functools.reduce(\n", "        lambda vocab_1, vocab_2: set(vocab_1[\"vocab\"][0]) | set(vocab_2[\"vocab\"][0]), vocabs.values()\n", "    )\n", "\n", "    vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))}\n", "\n", "    # replace white space with delimiter token\n", "    if word_delimiter_token is not None:\n", "        vocab_dict[word_delimiter_token] = vocab_dict[\" \"]\n", "        del vocab_dict[\" \"]\n", "\n", "    # add unk and pad token\n", "    if unk_token is not None:\n", "        vocab_dict[unk_token] = len(vocab_dict)\n", "\n", "    if pad_token is not None:\n", "        vocab_dict[pad_token] = len(vocab_dict)\n", "\n", "    return vocab_dict\n", "\n", "\n", "word_delimiter_token = \"|\"\n", "unk_token = \"[UNK]\"\n", "pad_token = \"[PAD]\"\n", "\n", "vocab_dict = create_vocabulary_from_data(\n", "    raw_datasets,\n", "    word_delimiter_token=word_delimiter_token,\n", "    unk_token=unk_token,\n", "    pad_token=pad_token,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "c4541263", "metadata": {}, "outputs": [], "source": [ "vocab_dict" ] }, { "attachments": {}, "cell_type": "markdown", "id": "20922bdc", "metadata": {}, "source": [ "With the vocabulary generated from the normalised transcripts, we create a tokenizer which is included in 
the [🤗 Transformers](https://huggingface.co/docs/transformers/index) library. This will later be used to encode text into indexes, and decode indexes into text." ] }, { "cell_type": "code", "execution_count": null, "id": "0d9fda29", "metadata": {}, "outputs": [], "source": [ "tokenizer_name_or_path = \"/tmp/wav2vec2-notebook\"\n", "\n", "vocab_file = os.path.join(tokenizer_name_or_path, \"vocab.json\")\n", "\n", "if os.path.isfile(vocab_file):\n", "    os.remove(vocab_file)\n", "\n", "os.makedirs(tokenizer_name_or_path, exist_ok=True)\n", "\n", "with open(vocab_file, \"w\") as file:\n", "    json.dump(vocab_dict, file)\n", "\n", "tokenizer_kwargs = {\n", "    \"config\": None,\n", "    \"tokenizer_type\": \"wav2vec2\",\n", "    \"unk_token\": unk_token,\n", "    \"pad_token\": pad_token,\n", "    \"word_delimiter_token\": word_delimiter_token,\n", "}\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_auth_token=False, **tokenizer_kwargs)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "d94d06af", "metadata": {}, "source": [ "Let's look at an example of using the tokenizer. The vocabulary does not contain any digits, so these will be set to `[UNK]`. Remember, any special characters (such as commas) have already been removed from the dataset." ] }, { "cell_type": "code", "execution_count": null, "id": "a82868fc", "metadata": {}, "outputs": [], "source": [ "tokenizer(\"wav2vec2 finetuning on ipu\")" ] }, { "cell_type": "code", "execution_count": null, "id": "bdae8729", "metadata": {}, "outputs": [], "source": [ "tokenizer.decode(tokenizer(\"wav2vec2 finetuning on ipu\").input_ids)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "f98bbe23", "metadata": {}, "source": [ "### Feature extraction\n", "\n", "Now we generate the feature extraction method for the model and map it across the datasets onto the audio data. 
In this model we are learning from a raw audio signal, so the feature extractor only has to normalise the waveform. If necessary, the audio itself is first resampled to the 16 kHz rate which the model expects. \n", "\n", "Afterwards we set the minimum and maximum input lengths in samples. These are set to 2.0 and 15.6 seconds, converted to 32000 and 249600 samples respectively for a sampling frequency of 16 kHz. " ] }, { "cell_type": "code", "execution_count": null, "id": "e53f1437", "metadata": {}, "outputs": [], "source": [ "feature_extractor = AutoFeatureExtractor.from_pretrained(\"facebook/wav2vec2-base\")\n", "\n", "dataset_sampling_rate = next(iter(raw_datasets.values())).features[\"audio\"].sampling_rate\n", "if dataset_sampling_rate != 16000:\n", "    raw_datasets = raw_datasets.cast_column(\"audio\", datasets.features.Audio(sampling_rate=16000))\n", "\n", "max_input_length = int(15.6 * feature_extractor.sampling_rate)\n", "min_input_length = int(2.0 * feature_extractor.sampling_rate)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "bdd71404", "metadata": {}, "source": [ "### Prepare dataset\n", "\n", "In this step, both the feature extraction and tokenization are applied to the audio and transcript, respectively. The feature extractor normalises the audio, and the tokenizer converts the normalised text into indexes.\n", "\n", "After the `map` function has completed, the dataset will be filtered by the audio length. If the length of the raw audio is not between 2.0 and 15.6 seconds then it will be removed from the data. The result of the filtering is cached." 
] }, { "cell_type": "code", "execution_count": null, "id": "f052fcf7", "metadata": {}, "outputs": [], "source": [ "def prepare_dataset(batch):\n", "    sample = batch[\"audio\"]\n", "\n", "    inputs = feature_extractor(sample[\"array\"], sampling_rate=sample[\"sampling_rate\"])\n", "    batch[\"input_values\"] = inputs.input_values[0]\n", "    batch[\"input_length\"] = len(inputs.input_values[0])\n", "\n", "    batch[\"labels\"] = tokenizer(batch[\"target_text\"]).input_ids\n", "\n", "    return batch\n", "\n", "\n", "def is_audio_in_length_range(length):\n", "    try:\n", "        return length > min_input_length and length < max_input_length\n", "    except TypeError:\n", "        # a missing or non-numeric length cannot be compared, so drop the sample\n", "        return False\n", "\n", "\n", "vectorized_datasets = raw_datasets.map(\n", "    prepare_dataset,\n", "    remove_columns=raw_datasets[\"train\"].column_names,\n", "    num_proc=8,\n", "    desc=\"preprocess datasets\",\n", ")\n", "\n", "vectorized_datasets = vectorized_datasets.filter(\n", "    is_audio_in_length_range,\n", "    input_columns=[\"input_length\"],\n", "    num_proc=8,\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "2dcce07e", "metadata": {}, "source": [ "## Data loading\n", "\n", "With the dataset prepared, the majority of the processing is complete and the data is almost ready to be sent to the model. The role of the collator is to pad the resampled audio and encoded text to a static size. The padding values for audio will be set to `0.0`, but for the indexes they will be `-100` so that they are not confused with an index in the vocabulary and are ignored by the loss." 
] }, { "cell_type": "code", "execution_count": null, "id": "3f57898b", "metadata": {}, "outputs": [], "source": [ "@dataclass\n", "class DataCollatorCTCWithPadding:\n", "    \"\"\"\n", "    Data collator that will dynamically pad the inputs received.\n", "    Args:\n", "        processor (:class:`~transformers.AutoProcessor`)\n", "            The processor used for processing the data.\n", "        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):\n", "            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)\n", "            among:\n", "            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single\n", "              sequence is provided).\n", "            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the\n", "              maximum acceptable input length for the model if that argument is not provided.\n", "            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of\n", "              different lengths).\n", "        max_length (:obj:`int`, `optional`):\n", "            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).\n", "        max_length_labels (:obj:`int`, `optional`):\n", "            Maximum length of the ``labels`` returned list and optionally padding length (see above).\n", "        pad_to_multiple_of (:obj:`int`, `optional`):\n", "            If set will pad the sequence to a multiple of the provided value.\n", "            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=\n", "            7.5 (Volta).\n", "    \"\"\"\n", "\n", "    processor: AutoProcessor\n", "    padding: Union[bool, str] = \"longest\"\n", "    pad_to_multiple_of: Optional[int] = None\n", "    pad_to_multiple_of_labels: Optional[int] = None\n", "\n", "    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n", "        # split inputs and labels since they have to be of different lengths and need\n", "        # different padding methods\n", "        input_features = [{\"input_values\": feature[\"input_values\"]} for feature in features]\n", "        label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n", "\n", "        batch = self.processor.pad(\n", "            input_features,\n", "            padding=self.padding,\n", "            pad_to_multiple_of=self.pad_to_multiple_of,\n", "            return_tensors=\"pt\",\n", "        )\n", "\n", "        with self.processor.as_target_processor():\n", "            labels_batch = self.processor.pad(\n", "                label_features,\n", "                padding=self.padding,\n", "                pad_to_multiple_of=self.pad_to_multiple_of_labels,\n", "                return_tensors=\"pt\",\n", "            )\n", "\n", "        # replace padding with -100 to ignore loss correctly\n", "        batch[\"labels\"] = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n", "        batch[\"input_values\"] = batch[\"input_values\"].half()\n", "\n", "        return batch.data\n", "\n", "\n", "processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)\n", "\n", "data_collator = DataCollatorCTCWithPadding(\n", "    processor=processor,\n", "    pad_to_multiple_of=int(max_input_length),\n", "    pad_to_multiple_of_labels=1000,\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "3a56b494", "metadata": {}, "source": [ "## Preparing the model\n", "\n", "For the model, we are using `wav2vec2-base` from the [🤗 Models Hub](https://huggingface.co/models). This model has been pre-trained only.\n", "Some of the default options for the model will need to be changed for training:\n", "* CTC loss will be normalised by the lengths.\n", "* There is no masking of the features to be applied so both masks are set to 0.0. 
The current masking strategy isn't supported on the IPU.\n", "* The [PAD] index and vocabulary size are later used in the model for the final output layer and CTC-loss.\n", "* Epsilon is adjusted for FP16 training.\n", "\n", "The IPU config describes how to parallelise the model across several IPUs. It also includes additional options such as gradient accumulation, device iterations, and memory proportion. " ] }, { "cell_type": "code", "execution_count": null, "id": "53b0e35e", "metadata": {}, "outputs": [], "source": [ "config = AutoConfig.from_pretrained(\"facebook/wav2vec2-base\")\n", "config.update(\n", " {\n", " \"ctc_loss_reduction\": \"mean\",\n", " \"mask_time_prob\": 0.0,\n", " \"mask_feature_prob\": 0.0,\n", " \"layerdrop\": 0.0,\n", " \"pad_token_id\": tokenizer.pad_token_id,\n", " \"vocab_size\": len(tokenizer),\n", " \"layer_norm_eps\": 0.0001,\n", " }\n", ")\n", "\n", "model = AutoModelForCTC.from_pretrained(\"facebook/wav2vec2-base\", config=config)\n", "\n", "ipu_config = IPUConfig.from_pretrained(\"Graphcore/wav2vec2-ctc-base-ipu\", executable_cache_dir=executable_cache_dir)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "61c69b73", "metadata": {}, "source": [ "Let's set our training hyperparameters using `IPUTrainingArguments`. This subclasses the Hugging Face `TrainingArguments` class, adding parameters specific to the IPU and its execution characteristics." 
] }, { "cell_type": "code", "execution_count": null, "id": "5d0c80b0", "metadata": {}, "outputs": [], "source": [ "training_args = IPUTrainingArguments(\n", "    output_dir=checkpoint_directory,\n", "    overwrite_output_dir=True,\n", "    do_train=True,\n", "    do_eval=True,\n", "    evaluation_strategy=\"epoch\",\n", "    learning_rate=3e-4,\n", "    num_train_epochs=5.0,\n", "    adam_epsilon=0.0001,\n", "    warmup_steps=400,\n", "    dataloader_drop_last=True,\n", "    dataloader_num_workers=16,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "082a3007", "metadata": {}, "outputs": [], "source": [ "feature_extractor.save_pretrained(training_args.output_dir)\n", "tokenizer.save_pretrained(training_args.output_dir)\n", "processor.save_pretrained(training_args.output_dir)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "d8c7df96", "metadata": {}, "source": [ "The performance of the model is measured using the WER. This metric takes a predicted string and the correct string and computes the word-level [edit distance](https://en.wikipedia.org/wiki/Edit_distance) between them, normalised by the number of words in the correct string. \n", "\n", "To add this metric to our evaluation, we define a `compute_metrics` function and load the metric from the `datasets` package. This is performed once after all the evaluation outputs have been computed." 
] }, { "cell_type": "code", "execution_count": null, "id": "77d1ba33", "metadata": {}, "outputs": [], "source": [ "eval_metrics = {\"wer\": load_metric(\"wer\")}\n", "\n", "\n", "def compute_metrics(pred):\n", "    pred_logits = pred.predictions\n", "    pred_ids = np.argmax(pred_logits, axis=-1)\n", "\n", "    pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id\n", "\n", "    pred_str = tokenizer.batch_decode(pred_ids)\n", "    # we do not want to group tokens when computing the metrics\n", "    label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)\n", "\n", "    metrics = {k: v.compute(predictions=pred_str, references=label_str) for k, v in eval_metrics.items()}\n", "\n", "    return metrics" ] }, { "attachments": {}, "cell_type": "markdown", "id": "97894b6e", "metadata": {}, "source": [ "To train the model, we define a trainer using the `IPUTrainer` class which takes care of compiling the model to run on IPUs, and of performing training and evaluation. The `IPUTrainer` class works just like the Hugging Face `Trainer` class, but takes the additional `ipu_config` argument." 
] }, { "cell_type": "code", "execution_count": null, "id": "b4f79a20", "metadata": {}, "outputs": [], "source": [ "# Initialize Trainer\n", "trainer = IPUTrainer(\n", "    model=model,\n", "    ipu_config=ipu_config,\n", "    data_collator=data_collator,\n", "    args=training_args,\n", "    compute_metrics=compute_metrics,\n", "    train_dataset=vectorized_datasets[\"train\"],\n", "    eval_dataset=vectorized_datasets[\"eval\"],\n", "    tokenizer=feature_extractor,\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "006e91b6", "metadata": {}, "source": [ "## Run the training" ] }, { "cell_type": "code", "execution_count": null, "id": "e04a4b42", "metadata": { "scrolled": true }, "outputs": [], "source": [ "trainer.train()\n", "trainer.save_model()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "98f4e5cb", "metadata": {}, "source": [ "## Next steps\n", "\n", "You can try out the Automatic Speech Recognition (ASR) on IPU using wav2vec - Inference notebook (`wav2vec2-inference-checkpoint.ipynb`) to use the outputs of this notebook.\n", "\n", "Also, try out the other [IPU-powered Jupyter Notebooks](https://www.graphcore.ai/ipu-jupyter-notebooks) to see how IPUs perform on other tasks." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, "vscode": { "interpreter": { "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } } }, "nbformat": 4, "nbformat_minor": 5 }