notebooks/packed_bert/packedBERT_question_answering.ipynb (1,177 lines of code) (raw):
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "156ed6e4",
"metadata": {},
"source": [
"# Faster Question-answering on IPU using Packed BERT - Fine-tuning and Inference\n",
"\n",
"This notebook describes how to fine-tune BERT from [🤗 Transformers](https://github.com/huggingface/transformers) for question answering on the SQuAD (v1) dataset with [packing](https://towardsdatascience.com/introducing-packed-bert-for-2x-faster-training-in-natural-language-processing-eadb749962b1). Packing is an optimisation method originally used for 2x faster BERT pre-training, and it can also provide large throughput increases for fine-tuning and batched inference.\n",
"\n",
"**So, what *is* packing?** The model requires inputs of a constant shape, so shorter sequences are normally padded up to the maximum sequence length. Packing makes use of this fixed shape: instead of filling it with empty, unused padding, we recycle that space and fill it with more sequences. The architecture of transformer models like BERT supports this, which lets us use the space to process multiple sequences within one input.\n",
"\n",
"**And here is why you might want to use it:** when a single input contains multiple sequences, those sequences are processed in parallel in a single pass. This increases the \"effective\" batch size of the model, often by a considerable factor. Most importantly, it significantly increases model throughput for training and batched inference.\n",
"\n",
"This notebook builds on the process described in the notebook on \"fine-tuning a model on a question-answering task\", `natural-language-processing/other-use-cases/question_answering.ipynb`. In this notebook, we show how to adapt the `BertForQuestionAnswering` model to accommodate a packed dataset."
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "2cd5485f",
"metadata": {},
"source": [
"| Domain | Tasks | Model | Datasets | Workflow | Number of IPUs | Execution time |\n",
"|---------|-------|-------|----------|----------|--------------|--------------|\n",
"| Natural language processing | Question answering | PackedBERT | SQuAD v1 | Training, inference | | |"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "05ee9609",
"metadata": {},
"source": [
"[Join the Graphcore community](https://www.graphcore.ai/join-community)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "e9f7a847",
"metadata": {},
"source": [
"## Environment setup\n",
"\n",
"The best way to run this demo is on Paperspace Gradient's cloud IPUs because everything is already set up for you.\n",
"\n",
"To run the demo using other IPU hardware, you need to have the Poplar SDK enabled. Refer to the [Getting Started guide](https://docs.graphcore.ai/en/latest/getting-started.html#getting-started) for your system for details on how to enable the Poplar SDK. Also refer to the [Jupyter Quick Start guide](https://docs.graphcore.ai/projects/jupyter-notebook-quick-start/en/latest/index.html) for how to set up Jupyter to be able to run this notebook on a remote IPU machine."
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "6bc957c3",
"metadata": {},
"source": [
"## Dependencies and configuration"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "48b6d85c",
"metadata": {},
"source": [
"In order to improve usability and support for future users, Graphcore would like to collect information about the\n",
"applications and code being run in this notebook. The following information will be anonymised before being sent to Graphcore:\n",
"\n",
"- User progression through the notebook\n",
"- Notebook details: number of cells, code being run and the output of the cells\n",
"- Environment details\n",
"\n",
"You can disable logging at any time by running `%unload_ext graphcore_cloud_tools.notebook_logging.gc_logger` from any cell."
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "f17ecdb5",
"metadata": {},
"source": [
"Install the dependencies for this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e98ec027",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"%pip install optimum-graphcore==0.7\n",
"%pip install scikit-learn \"datasets>=2.7.0\" evaluate tokenizers matplotlib scipy huggingface_hub\n",
"\n",
"%pip install graphcore-cloud-tools[logger]@git+https://github.com/graphcore/graphcore-cloud-tools\n",
"%load_ext graphcore_cloud_tools.notebook_logging.gc_logger"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "92ae1cca",
"metadata": {},
"source": [
"Values for machine size and cache directories can be configured through environment variables or directly in the notebook:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b882a5b3",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"\n",
"n_ipu = int(os.getenv(\"NUM_AVAILABLE_IPU\", 4))\n",
"executable_cache_dir = os.getenv(\"POPLAR_EXECUTABLE_CACHE_DIR\", \"./exe_cache/\") + \"/packed_bert_squad/\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "6e8ee76f",
"metadata": {},
"source": [
"In this notebook, we use both data parallelism and pipeline parallelism (see this [tutorial on efficient data loading](https://github.com/graphcore/examples/blob/master/tutorials/tutorials/pytorch/efficient_data_loading/walkthrough.ipynb) for more information). The \"global\" batch size, which here is the number of samples passed to the IPUs in one step (one call from the host), is therefore calculated from the micro batch size (`micro_batch_size`), the number of device iterations (`device_iterations`), the number of gradient accumulation steps (`gradient_accumulation_steps`) and the replication factor (`replication_factor`):\n",
"\n",
"```\n",
"global_training_batch_size = micro_batch_size * device_iterations * gradient_accumulation_steps * replication_factor\n",
"```\n",
"\n",
"Note: we define a \"micro\" batch size, which is the local batch size that would be passed into the model on the CPU.\n",
"\n",
"Depending on your model and the IPU Pod machine you are using, you might need to adjust these batch-size-related arguments.\n",
"\n",
"With packing, each input can contain up to `max_seq_per_pack` sequences, provided there is enough space for them. These sequences are processed in parallel within the model and use space that would otherwise be padding if one sequence were passed at a time. This is a much more efficient way to send inputs into the model, and it increases the effective global batch size to a best-case scenario of:\n",
"\n",
"```\n",
"global_training_batch_size = micro_batch_size * device_iterations * gradient_accumulation_steps * replication_factor * max_seq_per_pack\n",
"```\n",
"\n",
"Realistically, the global batch size will not always be multiplied by the *maximum* number of sequences in a packed sequence, but rather the *average* number of sequences in a packed sequence, and will depend on the sequence length distribution within any given dataset."
]
},
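{
"attachments": {},
"cell_type": "markdown",
"id": "3f8a2c11",
"metadata": {},
"source": [
"As a worked example of the two formulas above, the sketch below plugs in the values this notebook uses on a 4-IPU machine. With a single replica, `gradient_accumulation_steps` is 32. The average packing factor of 2.2 is an assumption borrowed from the speed-up reported later in this notebook; the actual value depends on the dataset.\n",
"\n",
"```python\n",
"# Batch-size arithmetic for the formulas above (illustrative values)\n",
"micro_batch_size = 2\n",
"device_iterations = 32\n",
"gradient_accumulation_steps = 32\n",
"replication_factor = 1\n",
"max_seq_per_pack = 6\n",
"avg_packing_factor = 2.2  # assumed average number of sequences per pack\n",
"\n",
"# Packs (model inputs) consumed per step\n",
"packs_per_step = (micro_batch_size * device_iterations\n",
"                  * gradient_accumulation_steps * replication_factor)\n",
"\n",
"# Sequences per step: best case (every pack full) and a realistic estimate\n",
"best_case_sequences = packs_per_step * max_seq_per_pack\n",
"expected_sequences = packs_per_step * avg_packing_factor\n",
"\n",
"print(packs_per_step, best_case_sequences, round(expected_sequences))  # prints: 2048 12288 4506\n",
"```"
]
},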
{
"cell_type": "code",
"execution_count": null,
"id": "0ad1b478",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"model_checkpoint=\"bert-base-uncased\" # Default uncased pre-trained BERT checkpoint\n",
"ipu_config_name=\"Graphcore/bert-base-uncased\" # Default Graphcore IPU config initialisation for pre-trained BERT\n",
"max_seq_length=512 # The maximum sequence length allowed for sequences in the model.\n",
"number_ipus_per_replica = 4\n",
"number_of_replicas = n_ipu // number_ipus_per_replica\n",
"# The size of the machine and the length of the pipeline impact the number of \n",
"# samples that need to be processed between gradient updates\n",
"gradient_accumulation_steps = 32 // number_of_replicas\n",
"device_iterations = 32\n",
"micro_batch_size=2\n",
"model_task=\"squad\" "
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "77dde875",
"metadata": {},
"source": [
"Gradients are not calculated during validation, so gradient accumulation is not applicable, and the global batch size for validation can be defined separately as:\n",
"\n",
"```\n",
"global_validation_batch_size=micro_batch_size*device_iterations*replication_factor*max_seq_per_pack\n",
"```\n",
"\n",
"In Optimum Graphcore, we can define inference-specific values for device iterations and replication factor (`inference_device_iterations` and `inference_replication_factor`). These can be adjusted to create larger batches that compensate for the lack of a gradient accumulation factor."
]
},
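{
"attachments": {},
"cell_type": "markdown",
"id": "9b41e7d2",
"metadata": {},
"source": [
"The same arithmetic for validation is shown below. It assumes the validation micro batch size equals `micro_batch_size = 2` and uses the inference values passed to `IPUConfig` later in this notebook (`inference_device_iterations=64`, `inference_replication_factor=1`). It is an illustration only.\n",
"\n",
"```python\n",
"# Validation batch-size arithmetic: no gradient accumulation term\n",
"micro_batch_size = 2               # assumed equal for training and validation\n",
"inference_device_iterations = 64  # value passed to IPUConfig later in this notebook\n",
"inference_replication_factor = 1\n",
"max_seq_per_pack = 6\n",
"\n",
"packs_per_validation_step = (micro_batch_size * inference_device_iterations\n",
"                             * inference_replication_factor)\n",
"best_case_sequences = packs_per_validation_step * max_seq_per_pack\n",
"print(packs_per_validation_step, best_case_sequences)  # prints: 128 768\n",
"```"
]
},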
{
"attachments": {},
"cell_type": "markdown",
"id": "33597c71",
"metadata": {},
"source": [
"## Loading the dataset\n",
"\n",
"The next step is to use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download the dataset from the 🤗 Hub, and to use the [🤗 Evaluate](https://github.com/huggingface/evaluate) library to load the evaluation metrics for the SQuAD model. This will allow easy performance metric analysis during validation."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b37cb293",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"import evaluate\n",
"\n",
"dataset = load_dataset(model_task) # Load dataset\n",
"metric = evaluate.load(model_task) # Load metric for dataset"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "6dac6eca",
"metadata": {},
"source": [
"The `dataset` object itself is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key for each split of the dataset. For SQuAD, these are the training and validation sets:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2115928b",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"dataset"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "23d3f421",
"metadata": {},
"source": [
"To access an actual element, you need to select a split (\"train\" in the example) and then specify an index:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "311b8b73",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"dataset[\"train\"][0]"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "3702f2a3",
"metadata": {},
"source": [
"In the SQuAD dataset, each example has a `question`, its `context` (an excerpt of text that includes the answer as well as surrounding text), and an `answers` key. The `answers` key holds the answer text and its start position (as a character index) in the context. For a different or custom question-answering dataset, these fields may have different names but serve the same purpose, so defining them up front is useful.\n",
"\n",
"The next cell defines these keys. They identify the raw data that needs to be pre-processed and tokenised before being passed into the model. The keys may change for custom datasets, but their usage generally stays the same for a similar fine-tuning task."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "628bc41f",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"question_key=\"question\"\n",
"context_key=\"context\"\n",
"answer_key=\"answers\"\n",
"train = True\n",
"validate = True"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "793dcd19",
"metadata": {},
"source": [
"## Preprocessing the data\n",
"\n",
"Before we can feed these text samples to our model, we need to preprocess them. This is done by a 🤗 Transformers `Tokenizer`. As the name indicates, it tokenizes the inputs, which includes converting the tokens to their corresponding IDs in the pre-trained vocabulary. It also puts them in a format the model expects and generates the other inputs that the model requires.\n",
"\n",
"To do all of this, we instantiate our tokenizer with the `AutoTokenizer.from_pretrained` method, which will ensure:\n",
"\n",
"- We get a tokenizer that corresponds to the model architecture we want to use.\n",
"- We download the vocabulary used when pre-training this specific checkpoint.\n",
"\n",
"That vocabulary will be cached, so it's not downloaded again the next time we run the cell.\n",
"\n",
"The `Dataset` class is also imported. It lets us convert our modified and tokenized columns from dictionary form into a dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aab94819",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"from datasets import Dataset \n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "a47ea927",
"metadata": {},
"source": [
"For SQuAD, we define a custom function to handle the overflows and offset mapping created by generating tokenised inputs from sequences, as well as the start and end positions of the answers which need to be translated from positions of characters to positions of tokens.\n",
"\n",
"The first step is to tokenize the dataset using the tokenizer. Note here that for packing, it is important to **not** pad the dataset, so `padding` should be set to `False`. If we pad, we will have to un-pad when packing sequences into a packed sequence, which is inefficient.\n",
"\n",
"The details of the preprocessing function are outlined in the original (unpacked) question-answering notebook `natural-language-processing/other-use-cases/question_answering.ipynb`. Here, we can import the preprocessing function directly from `utils.packing` and call it *without* padding for PackedBERT."
]
},
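{
"attachments": {},
"cell_type": "markdown",
"id": "c52d0a6e",
"metadata": {},
"source": [
"The character-to-token conversion described above can be sketched with a hand-written offset mapping. It stands in for the `(start_char, end_char)` pairs that a fast tokenizer returns when called with `return_offsets_mapping=True`. The helper `char_span_to_token_span` is hypothetical and only covers the context tokens of a single window. `preprocess_packed_qa` also has to handle question tokens, overflowing windows, and answers that fall outside a window.\n",
"\n",
"```python\n",
"def char_span_to_token_span(offsets, answer_start, answer_text):\n",
"    # Hypothetical helper: map a character-level answer span to token positions.\n",
"    # offsets holds one (start_char, end_char) pair per context token.\n",
"    answer_end = answer_start + len(answer_text)  # exclusive end character\n",
"    start_token = end_token = None\n",
"    for i, (s, e) in enumerate(offsets):\n",
"        if start_token is None and s <= answer_start < e:\n",
"            start_token = i\n",
"        if s < answer_end <= e:\n",
"            end_token = i\n",
"    return start_token, end_token\n",
"\n",
"\n",
"context = 'The IPU was designed by Graphcore in Bristol.'\n",
"# Word-level offsets standing in for real tokenizer output\n",
"offsets = [(0, 3), (4, 7), (8, 11), (12, 20), (21, 23), (24, 33), (34, 36), (37, 44), (44, 45)]\n",
"# SQuAD-style answers field\n",
"answers = {'text': ['Graphcore'], 'answer_start': [24]}\n",
"\n",
"start, end = char_span_to_token_span(offsets, answers['answer_start'][0], answers['text'][0])\n",
"print(start, end)  # prints: 5 5\n",
"```"
]
},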
{
"cell_type": "code",
"execution_count": null,
"id": "2263dfef",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"from utils.packing.qa_utils import preprocess_packed_qa\n",
"\n",
"raw_train_dataset = dataset['train']\n",
"\n",
"tokenized_training_dataset = preprocess_packed_qa(\n",
" dataset=raw_train_dataset,\n",
" tokenizer=tokenizer,\n",
" question_key=question_key,\n",
" context_key=context_key,\n",
" answer_key=answer_key,\n",
" sequence_length=max_seq_length,\n",
" padding=False,\n",
" train=True\n",
")\n",
"\n",
"\n",
"raw_validation_dataset = dataset['validation']\n",
"\n",
"tokenized_validation_dataset = preprocess_packed_qa(\n",
" dataset=raw_validation_dataset,\n",
" tokenizer=tokenizer,\n",
" question_key=question_key,\n",
" context_key=context_key,\n",
" answer_key=answer_key,\n",
" sequence_length=max_seq_length,\n",
" padding=False,\n",
" train=False\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "f57906e8",
"metadata": {},
"source": [
"## Packing the dataset\n",
"\n",
"To implement packing, we need to pack our dataset first. Each new element will be a \"pack\" containing at most `max_seq_per_pack` sequences."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6bdd1b9e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"max_seq_per_pack = 6"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "51c17c9b",
"metadata": {},
"source": [
"We define the number of labels in our dataset. For SQuAD, this means the number of outputs, which are the positions returned by the model. Since this is not a classification task, `num_labels` is set to 2, to correspond to start and end positions.\n",
"\n",
"We also define the problem type."
]
},
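{
"attachments": {},
"cell_type": "markdown",
"id": "7e19b4f3",
"metadata": {},
"source": [
"To make the two outputs concrete: in 🤗 Transformers, `BertForQuestionAnswering` applies a linear layer with `num_labels` outputs to every token's hidden state. It then splits the result into start and end logits. The NumPy sketch below mimics that split with random weights and made-up shapes. It is an illustration, not the packed model's head, which also has to separate the logits of the sequences in each pack.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"batch, seq_len, hidden, num_labels = 2, 8, 16, 2\n",
"rng = np.random.default_rng(0)\n",
"\n",
"hidden_states = rng.standard_normal((batch, seq_len, hidden))\n",
"# Stand-in linear head: maps each token's hidden state to num_labels scores\n",
"weight = rng.standard_normal((hidden, num_labels))\n",
"bias = np.zeros(num_labels)\n",
"\n",
"logits = hidden_states @ weight + bias  # (batch, seq_len, 2)\n",
"start_logits = logits[..., 0]           # (batch, seq_len)\n",
"end_logits = logits[..., 1]             # (batch, seq_len)\n",
"\n",
"# Highest-scoring start and end token for each input in the batch\n",
"print(start_logits.argmax(axis=-1), end_logits.argmax(axis=-1))\n",
"```"
]
},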
{
"cell_type": "code",
"execution_count": null,
"id": "bfda406f",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"num_labels = 2\n",
"problem_type = 'question_answering'"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "c39316ea",
"metadata": {},
"source": [
"### Packing algorithm"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "d9d24ab5",
"metadata": {},
"source": [
"In order to pack efficiently, we will use a histogram-based algorithm. The shortest-pack-first histogram packing (SPFHP) was presented in the Graphcore blog post [introducing Packed BERT for a training speedup in natural language processing](https://www.graphcore.ai/posts/introducing-packed-bert-for-2x-faster-training-in-natural-language-processing). We have adapted the [code](https://github.com/graphcore/tutorials/tree/master/blogs_code/packedBERT) from the blog post for this notebook. The full process of packing the dataset consists of four steps:\n",
"\n",
"1. Create a histogram of the sequence lengths of the dataset.\n",
"2. Generate the \"strategy\" for the dataset using one of the state-of-the-art packing algorithms. The strategy maps out the order and indices of the sequences that need to be packed together.\n",
"3. Use this strategy to create the actual dataset, concatenating the tokenized features together for each column in the dataset, including the labels.\n",
"4. Finally, pass these new columns into a custom PyTorch dataset, ready to be passed to the PopTorch dataloader!\n",
"\n",
"These steps have been simplified by the `utils.packing` package provided with this notebook. After the usual tokenization and preprocessing, pass all the necessary packing configuration to the `PackedDatasetCreator` class, then generate the ready-to-use PyTorch dataset with `.create()`.\n",
"\n",
"Within the class, some column names are used by default. The expected default columns include:\n",
"* `input_ids`\n",
"* `attention_mask`\n",
"* `token_type_ids`\n",
"* `labels`\n",
"\n",
"The tokenizer generates these columns automatically for BERT, except for the labels. Because the labels are not encoded by the tokenizer, their column may have a different name. If a dataset stores its labels under another key (for example, `label` instead of `labels`), pass that key to the `custom_label_key` argument so the class can find the labels. For SQuAD, the labels are the answer start and end positions computed during preprocessing, and `custom_label_key` is not used here.\n",
"\n",
"The `PackedDatasetCreator` must be instantiated separately for each dataset split. Set `training`, `validation` or `inference` to `True` as needed."
]
},
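{
"attachments": {},
"cell_type": "markdown",
"id": "a84f2d90",
"metadata": {},
"source": [
"The core idea behind step 2 can be illustrated with a simplified, per-sequence version of shortest-pack-first packing. Sequences are placed longest first, each into the open pack with the most remaining space (the *shortest* pack). A new pack is started when the sequence does not fit. This is a sketch under stated assumptions, not the `PackedDatasetCreator` implementation. SPFHP works on histogram counts rather than individual sequences, which is what makes it fast on large datasets. `shortest_pack_first` is a hypothetical name.\n",
"\n",
"```python\n",
"import heapq\n",
"\n",
"\n",
"def shortest_pack_first(lengths, max_sequence_length, max_sequences_per_pack):\n",
"    # Simplified shortest-pack-first packing over individual sequences (not histogram counts).\n",
"    # Returns a list of packs, each a list of indices into lengths.\n",
"    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)  # longest first\n",
"    packs = []\n",
"    open_packs = []  # max-heap on remaining space, stored as (-remaining_space, pack_id)\n",
"    for idx in order:\n",
"        length = lengths[idx]\n",
"        if open_packs and -open_packs[0][0] >= length:\n",
"            # The shortest open pack (most remaining space) can take this sequence\n",
"            neg_space, pack_id = heapq.heappop(open_packs)\n",
"            remaining = -neg_space - length\n",
"            packs[pack_id].append(idx)\n",
"        else:\n",
"            # If the shortest pack has no room, no other pack does either: start a new one\n",
"            pack_id = len(packs)\n",
"            packs.append([idx])\n",
"            remaining = max_sequence_length - length\n",
"        # Keep the pack open only while it has space and is below the sequence limit\n",
"        if remaining > 0 and len(packs[pack_id]) < max_sequences_per_pack:\n",
"            heapq.heappush(open_packs, (-remaining, pack_id))\n",
"    return packs\n",
"\n",
"\n",
"lengths = [300, 250, 200, 120, 100, 60, 40, 30]\n",
"packs = shortest_pack_first(lengths, max_sequence_length=512, max_sequences_per_pack=6)\n",
"print(packs)  # prints: [[0, 3], [1, 2], [4, 5, 6, 7]]\n",
"print(f'average packing factor: {len(lengths) / len(packs):.2f}')  # prints: average packing factor: 2.67\n",
"```\n",
"\n",
"Keeping the open packs in a heap ordered by remaining space makes finding the shortest pack cheap. Note that a sequence only needs to be checked against that one pack: if it does not fit there, it fits nowhere."
]
},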
{
"cell_type": "code",
"execution_count": null,
"id": "e66ed06d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from utils.packing.dataset_creator import PackedDatasetCreator\n",
"\n",
"train_data_packer = PackedDatasetCreator(\n",
" tokenized_dataset = tokenized_training_dataset,\n",
" max_sequence_length = max_seq_length,\n",
" max_sequences_per_pack = max_seq_per_pack,\n",
" training = True,\n",
" num_labels = num_labels,\n",
" problem_type = problem_type,\n",
" algorithm = 'SPFHP'\n",
")\n",
"\n",
"val_data_packer = PackedDatasetCreator(\n",
" tokenized_dataset = tokenized_validation_dataset,\n",
" max_sequence_length = max_seq_length,\n",
" max_sequences_per_pack = max_seq_per_pack,\n",
" validation = True,\n",
" num_labels = num_labels,\n",
" problem_type = problem_type,\n",
" algorithm = 'SPFHP'\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "720ea314",
"metadata": {},
"source": [
"This creates the strategy and initialises the necessary parameters for packing the dataset. The reported statistics show an ideal speed-up of approximately 2.2x over the original dataset. This corresponds directly to the average packing factor, which is the average number of sequences within one pack.\n",
"\n",
"The `PackedDatasetCreator` class has some other features that we do not use here for training, such as `pad_to_global_batch_size`. This feature is useful for batched inference on a large set of samples when we do not want to lose any samples while creating data iterators with `poptorch.DataLoader`. It applies 'vertical' padding to the dataset, adding filler rows so that the number of rows is divisible by the global batch size. This allows the largest possible batch sizes to be used without any loss of data."
]
},
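{
"attachments": {},
"cell_type": "markdown",
"id": "5d3c7e18",
"metadata": {},
"source": [
"The amount of 'vertical' padding is simple modular arithmetic. The sketch below uses a hypothetical helper, `num_filler_rows`, and an illustrative dataset of 5,000 packs with a global batch size of 128. It shows how many filler rows are needed, not how `PackedDatasetCreator` constructs them.\n",
"\n",
"```python\n",
"def num_filler_rows(num_packs, global_batch_size):\n",
"    # Hypothetical helper: rows to add so the row count is divisible by the global batch size\n",
"    return (-num_packs) % global_batch_size\n",
"\n",
"\n",
"num_packs = 5000         # illustrative number of packs\n",
"global_batch_size = 128  # e.g. micro batch 2 x inference device iterations 64 x 1 replica\n",
"\n",
"fill = num_filler_rows(num_packs, global_batch_size)\n",
"padded = num_packs + fill\n",
"print(fill, padded, padded // global_batch_size)  # prints: 120 5120 40\n",
"\n",
"# Without padding, dropping the last incomplete batch would lose this many packs\n",
"print(num_packs % global_batch_size)  # prints: 8\n",
"```"
]
},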
{
"attachments": {},
"cell_type": "markdown",
"id": "46319488",
"metadata": {},
"source": [
"You can also view the histogram generated in the first step of the packing process to check whether the distribution of sequence lengths in the dataset will benefit from packing. As a general rule, the ideal throughput increase is roughly the ratio of the maximum sequence length to the mean sequence length: `throughput_increase ≈ max_seq_len/mean_seq_len`. So if the average sequence length is 50% or less of the maximum sequence length, packing can offer a 2x or greater throughput benefit in the ideal case.\n",
"\n",
"Many datasets have distributions with much smaller average lengths, and will benefit much more. We can easily observe this distribution by retrieving and plotting the histogram from the data class:"
]
},
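{
"attachments": {},
"cell_type": "markdown",
"id": "e02b96a4",
"metadata": {},
"source": [
"The rule of thumb above can be estimated directly from a histogram of sequence lengths. This sketch builds a histogram with `np.bincount` from a handful of made-up lengths (not SQuAD data) and recovers the mean length from it. It then caps the estimate at `max_seq_per_pack`, since a pack can never hold more sequences than that. The result is an upper bound, because real packs are rarely filled perfectly.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"max_seq_length = 512\n",
"max_seq_per_pack = 6\n",
"seq_lens = np.array([150, 180, 200, 160, 140, 230, 170, 190])  # made-up lengths\n",
"\n",
"# histogram[n] is the number of sequences with exactly n tokens\n",
"histogram = np.bincount(seq_lens, minlength=max_seq_length + 1)\n",
"lengths = np.arange(len(histogram))\n",
"mean_len = (lengths * histogram).sum() / histogram.sum()\n",
"\n",
"ideal_speedup = max_seq_length / mean_len\n",
"# A pack never holds more than max_seq_per_pack sequences, so cap the estimate\n",
"capped_speedup = min(ideal_speedup, max_seq_per_pack)\n",
"print(f'mean length {mean_len:.1f}, ideal speed-up {capped_speedup:.2f}x')  # prints: mean length 177.5, ideal speed-up 2.88x\n",
"```"
]
},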
{
"cell_type": "code",
"execution_count": null,
"id": "113b58f4",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"train_histogram = train_data_packer.histogram\n",
"\n",
"plt.hist(train_histogram, bins = [k for k in range(0,max_seq_length,10)]) \n",
"plt.title(\"Sequence length histogram\") \n",
"plt.xlabel('Sequence lengths')\n",
"plt.ylabel('Frequency')\n",
"plt.show()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "1d077b97",
"metadata": {},
"source": [
"Now we need to create the actual packed dataset (step 3 of the packing process outlined above).\n",
"\n",
"In this stage, we take the strategy generated by the packing algorithm, which groups sequences into \"packs\" by size. We use it to extract the sequences from the tokenized dataset and insert them into packs for each column in the dataset. Any space remaining in a pack after its sequences have been concatenated is padded, so every pack reaches the maximum sequence length.\n",
"\n",
"**Some key features unique to packed datasets are worth mentioning here**:\n",
"\n",
"- The attention mask (`attention_mask`) generated for a pack contains a unique index for each sequence in the pack and `0` for the remaining padding tokens. This tells the model which tokens each token may attend to: tokens with the same index belong to the same sequence. Tokens from other sequences and padding are ignored.\n",
" - Example of 3 sequences in a pack: `attention_mask = [1,1,1,1,1,1,2,2,2,2,2,3,3,3,3,3,0,0,0]`\n",
" - Compared to a single sequence in an unpacked input `attention_mask = [1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0]`\n",
" \n",
"\n",
"- `position_ids` for a pack contain the concatenated `position_ids` of each sequence.\n",
" - For instance, given 3 sequences: `[0,1,2,3,4] + [0,1,2,3] + [0,1,2] -> [1,2,3,4,1,2,3,1,2,...,0,0,0]` (note: the CLS tokens, whose position ID is `0`, are moved to the end of the pack)\n",
" \n",
" \n",
"- For SQuAD, during training, answers are determined using a start position and end position within the sequence. During preprocessing, these were converted from character positions to token positions. Now, during packing, as tokenized sequences are effectively being concatenated along the same dimension, the positions of the answer will change for any sequence that is not starting at index 0 within a pack. For example, in a pack with 2 sequences:\n",
" - Answer positions before packing:\n",
" ```\n",
" Length of sequence 1: 100 tokens (index 0 to 99) , start position: 30, end position: 35\n",
" Length of sequence 2: 120 tokens (index 0 to 119) , start position: 15, end position: 25\n",
" ```\n",
" - Answer positions after packing:\n",
" ```\n",
" Length of sequence 1 in pack 1: 100 tokens (index 0 to 99) , start position: 30, end position: 35\n",
" Length of sequence 2 in pack 1: 120 tokens (index 100 to 219), start position: 115, end position: 125 \n",
" ```\n",
"\n",
" - The positions have been shifted by the total length of preceding sequences in the pack. We call this the position offset.\n",
"\n",
"\n",
"To create a dataloader-ready packed dataset, all you need to do is call the `create()` method:"
]
},
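{
"attachments": {},
"cell_type": "markdown",
"id": "1fa7c53b",
"metadata": {},
"source": [
"The attention mask, position IDs and answer offsets described above can be sketched for a single pack. `pack_sequences` is a hypothetical helper. It writes each sequence's index into the attention mask, restarts the position IDs for each sequence, and shifts the answer positions by the length of the preceding sequences, which reproduces the two-sequence example above. It does not reproduce the CLS-to-end reordering of position IDs used by the packing utilities. The last lines show one common way to expand the per-token index mask into a block-diagonal attention matrix; the packed model may implement this differently.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def pack_sequences(lengths, answers, max_seq_length):\n",
"    # Hypothetical helper. lengths: token count of each sequence in the pack;\n",
"    # answers: (start, end) token positions of each answer within its own sequence.\n",
"    attention_mask = np.zeros(max_seq_length, dtype=np.int64)\n",
"    position_ids = np.zeros(max_seq_length, dtype=np.int64)\n",
"    packed_answers = []\n",
"    offset = 0  # total length of the preceding sequences in the pack\n",
"    for seq_idx, (length, (ans_start, ans_end)) in enumerate(zip(lengths, answers), start=1):\n",
"        attention_mask[offset:offset + length] = seq_idx  # unique index per sequence\n",
"        position_ids[offset:offset + length] = np.arange(length)  # restart at 0\n",
"        packed_answers.append((ans_start + offset, ans_end + offset))\n",
"        offset += length\n",
"    return attention_mask, position_ids, packed_answers\n",
"\n",
"\n",
"# The two-sequence example above: lengths 100 and 120 tokens\n",
"mask, pos, ans = pack_sequences([100, 120], [(30, 35), (15, 25)], max_seq_length=512)\n",
"print(ans)  # prints: [(30, 35), (115, 125)]\n",
"\n",
"# Block-diagonal expansion: token i may attend to token j only if both\n",
"# carry the same non-zero sequence index\n",
"attend = (mask[:, None] == mask[None, :]) & (mask[:, None] > 0)\n",
"print(attend.shape, attend[0, 99], attend[0, 100])  # prints: (512, 512) True False\n",
"```"
]
},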
{
"cell_type": "code",
"execution_count": null,
"id": "bdcc161d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"packed_train_dataset = train_data_packer.create()\n",
"packed_val_dataset = val_data_packer.create()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "ce443c8f",
"metadata": {},
"source": [
"Let's visualize one sample of the new `packed_train_dataset`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c966cd9a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"packed_train_dataset[133]"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "a1d8ce9c",
"metadata": {},
"source": [
"## Fine-tuning the model\n",
"\n",
"Now that our data is ready, we can download the pre-trained model and fine-tune it.\n",
"\n",
"### Implement Packed BERT\n",
"\n",
"Some model modifications are required to make packing work with BERT. For SQuAD, we create a custom output class to separate the logits according to each of the sequences within the pack and calculate the loss. The existing class `BertForQuestionAnswering` is extended to `PipelinedPackedBertForQuestionAnswering` which incorporates the required modifications to the model. The crux of these changes is to introduce the new attention mask, and modify the hidden layer output of the model to mask any padded inputs from the logits.\n",
"\n",
"First, let's load a default BERT configuration using `AutoConfig`. The config includes a new parameter we must set, `max_sequences_per_pack`, which tells the model the maximum number of sequences it will need to unpack in the model output. We also set `num_labels` for this model."
]
},
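{
"attachments": {},
"cell_type": "markdown",
"id": "b6e48d27",
"metadata": {},
"source": [
"One way to separate the logits of the sequences in a pack is to build one row per sequence slot and mask out every token that does not belong to that sequence. The NumPy sketch below does this for the start logits of a single pack. `unpack_logits` is a hypothetical helper, and the masking value of `-1e4` is an assumption chosen to stay representable in half precision. It illustrates the idea rather than the code in `models.modeling_bert_packed`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def unpack_logits(logits, attention_mask, max_sequences_per_pack):\n",
"    # Hypothetical helper: one row of logits per sequence slot in the pack.\n",
"    # Tokens outside sequence k (other sequences or padding) get a large negative value,\n",
"    # so a softmax or argmax over row k only considers the tokens of sequence k.\n",
"    seq_ids = np.arange(1, max_sequences_per_pack + 1)[:, None]  # (max_sequences_per_pack, 1)\n",
"    in_seq = attention_mask[None, :] == seq_ids                  # (max_sequences_per_pack, seq_len)\n",
"    return np.where(in_seq, logits[None, :], -1e4)\n",
"\n",
"\n",
"attention_mask = np.array([1, 1, 1, 2, 2, 0, 0])\n",
"start_logits = np.array([0.1, 2.0, 0.3, 1.5, 0.2, 9.0, 9.0])\n",
"\n",
"per_seq = unpack_logits(start_logits, attention_mask, max_sequences_per_pack=3)\n",
"print(per_seq.argmax(axis=-1))  # prints: [1 3 0]\n",
"```\n",
"\n",
"The high scores on the padding tokens are ignored. Rows for empty sequence slots (the third row here) contain only masked values, so their outputs, and any loss computed from them, should be discarded."
]
},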
{
"cell_type": "code",
"execution_count": null,
"id": "254a0f83",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from transformers import AutoConfig\n",
"\n",
"config = AutoConfig.from_pretrained(model_checkpoint)\n",
"config.max_sequences_per_pack = max_seq_per_pack\n",
"config.num_labels = num_labels"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "0b7dae37",
"metadata": {},
"source": [
"Now we can instantiate the model class with the config, loading the weights from the model checkpoint. For SQuAD, the two \"labels\" set by `num_labels` are the start and end positions of the answer, which determine whether an answer is correct."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3285aaa3",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"torch.manual_seed(43)\n",
"np.random.seed(43)\n",
" \n",
"from models.modeling_bert_packed import PipelinedPackedBertForQuestionAnswering\n",
"\n",
"model = PipelinedPackedBertForQuestionAnswering.from_pretrained(model_checkpoint, config=config)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "d6070000",
"metadata": {},
"source": [
"The warning tells us we are throwing away some weights and randomly initializing others. This is expected here. We are removing the head used to pre-train the model on a masked language modelling objective and replacing it with a new head for question answering, for which we don't have pre-trained weights. The library therefore warns us that we should fine-tune this model before using it for inference, which is exactly what we are going to do.\n",
"\n",
"We can first test the model on a CPU."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "02aac4e1",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# test the model on CPU\n",
"from transformers.data.data_collator import default_data_collator\n",
"\n",
"loader = torch.utils.data.DataLoader(packed_train_dataset,\n",
" batch_size=2,\n",
" shuffle=True,\n",
" drop_last=True,\n",
" collate_fn=default_data_collator)\n",
"data = next(iter(loader))\n",
"o = model(**data)\n",
"print(\"Model output:\", o)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "5c9922fa",
"metadata": {},
"source": [
"Now, let's prepare the model for an IPU.\n",
"\n",
"First, we set the model in half-precision:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11502853",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"model.half()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "7a9ab06f",
"metadata": {},
"source": [
"### Define validation metrics for SQuAD\n",
"\n",
"Before training and evaluating, a custom post-processing function needs to be defined for SQuAD. This is because we need to map the predictions of the model back to parts of the context in terms of the character positions in the original untokenized samples. The model predicts logits for the start and end token position of the answer.\n",
"\n",
"The purpose of the function is to identify each of the tokenized features according to `example_ids` and map the start and end token positions for the output, taking the top-*n* logit indices and discarding all invalid solutions. It then uses `offset_mapping` to map the start and end token-level positions back to character-level positions within the context, and generates a text answer using the original context. This text prediction can then be used to calculate accuracy metrics and can be compared to the target answer present in the dataset.\n",
"\n",
"The `postprocess_packed_qa_predictions()` function is adapted for packing from the `postprocess_qa_predictions()` function in the existing [tutorial for fine-tuning SQuAD for the IPU](https://github.com/huggingface/optimum-graphcore/blob/main/notebooks/question_answering.ipynb), which uses an unpacked dataset. That tutorial gives the full description of how the function is used.\n",
"\n",
"The main changes to the function for packing include: \n",
"* Instead of iterating through all the features in the tokenized dataset, and obtaining the `example_id` field created during tokenization of the validation dataset, this function iterates through each feature within each pack, obtaining the corresponding `example_id` value for each feature within the pack. \n",
"\n",
"* It saves the index of the pack in the dataset, **as well as the index of the feature within the pack**, to allow the function to easily and linearly obtain the features to perform validation on.\n",
"\n",
"This post-processing is available, ready to use, from the packing utilities in `utils.packing`, and can simply be imported."
]
},
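{
"attachments": {},
"cell_type": "markdown",
"id": "48c91f0e",
"metadata": {},
"source": [
"Here is a heavily simplified sketch of the selection step for a single, unpacked feature. It takes the top-*n* start and end indices, discards invalid spans, and maps the best span back to characters using the offsets. As in the 🤗 question-answering tutorial, it assumes that the offsets of non-context tokens are set to `None`. `best_answer`, `n_best` and `max_answer_len` are illustrative names. Unlike `postprocess_packed_qa_predictions`, the sketch does not group features by `example_id` or track pack indices.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def best_answer(start_logits, end_logits, offsets, context, n_best=5, max_answer_len=30):\n",
"    # offsets[i] is the (start_char, end_char) of token i, or None for non-context tokens\n",
"    best_score, best_text = -np.inf, ''\n",
"    top_starts = np.argsort(start_logits)[::-1][:n_best]\n",
"    top_ends = np.argsort(end_logits)[::-1][:n_best]\n",
"    for s in top_starts:\n",
"        for e in top_ends:\n",
"            # Discard spans outside the context, reversed spans and overly long spans\n",
"            if offsets[s] is None or offsets[e] is None:\n",
"                continue\n",
"            if e < s or e - s + 1 > max_answer_len:\n",
"                continue\n",
"            score = start_logits[s] + end_logits[e]\n",
"            if score > best_score:\n",
"                best_score = score\n",
"                best_text = context[offsets[s][0]:offsets[e][1]]\n",
"    return best_text\n",
"\n",
"\n",
"context = 'The IPU was designed by Graphcore in Bristol.'\n",
"# Token 0 stands in for a special or question token, so it has no context offsets\n",
"offsets = [None, (0, 3), (4, 7), (8, 11), (12, 20), (21, 23), (24, 33), (34, 36), (37, 44), (44, 45)]\n",
"start_logits = np.array([5.0, 0.1, 0.2, 0.1, 0.3, 0.2, 3.0, 0.1, 1.0, 0.0])\n",
"end_logits = np.array([5.0, 0.1, 0.1, 0.1, 0.2, 0.1, 1.0, 0.4, 2.5, 0.3])\n",
"print(best_answer(start_logits, end_logits, offsets, context))  # prints: Graphcore in Bristol\n",
"```\n",
"\n",
"A plain argmax would pick token 0 for both start and end. Discarding that invalid span is what recovers a real answer from the context."
]
},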
{
"cell_type": "code",
"execution_count": null,
"id": "f4c6e4d0",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from utils.packing.qa_utils import postprocess_packed_qa_predictions"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "38c76b0b",
"metadata": {},
"source": [
"Finally, we create a `compute_validation_metrics` function that takes in the model predictions. It post-processes them, obtains the target answers from the dataset, and matches them to the corresponding predictions by their `id` value. It then uses `metric` from the 🤗 Evaluate library to compute the relevant metrics for SQuAD: an \"exact match\" accuracy and an F1 score."
]
},
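{
"attachments": {},
"cell_type": "markdown",
"id": "d7a35b62",
"metadata": {},
"source": [
"For reference, the per-question scores behind the SQuAD metric can be sketched in plain Python. Answers are first normalised: lowercased, with punctuation and the articles *a*, *an* and *the* removed. Exact match compares the normalised strings, and F1 measures token overlap, taking the best score over all reference answers. This is a simplified re-implementation for illustration. The 🤗 Evaluate `squad` metric averages such scores over the dataset and reports them as percentages.\n",
"\n",
"```python\n",
"import string\n",
"from collections import Counter\n",
"\n",
"\n",
"def normalize_answer(text):\n",
"    # Lowercase, drop punctuation and articles, collapse whitespace\n",
"    text = ''.join(ch for ch in text.lower() if ch not in string.punctuation)\n",
"    return ' '.join(t for t in text.split() if t not in ('a', 'an', 'the'))\n",
"\n",
"\n",
"def exact_match(prediction, ground_truths):\n",
"    return max(float(normalize_answer(prediction) == normalize_answer(gt)) for gt in ground_truths)\n",
"\n",
"\n",
"def f1(prediction, ground_truths):\n",
"    def single_f1(pred, gt):\n",
"        pred_tokens = normalize_answer(pred).split()\n",
"        gt_tokens = normalize_answer(gt).split()\n",
"        num_same = sum((Counter(pred_tokens) & Counter(gt_tokens)).values())\n",
"        if num_same == 0:\n",
"            return 0.0\n",
"        precision = num_same / len(pred_tokens)\n",
"        recall = num_same / len(gt_tokens)\n",
"        return 2 * precision * recall / (precision + recall)\n",
"\n",
"    # Best score over all reference answers\n",
"    return max(single_f1(prediction, gt) for gt in ground_truths)\n",
"\n",
"\n",
"answers = ['Graphcore', 'Graphcore Ltd']\n",
"print(exact_match('the Graphcore', answers), round(f1('Graphcore in Bristol', answers), 3))  # prints: 1.0 0.5\n",
"```"
]
},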
{
"cell_type": "code",
"execution_count": null,
"id": "5420ef6a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def compute_validation_metrics(predictions, raw_validation_dataset, packed_validation_dataset_unformatted, metric):\n",
" \n",
" target_answers = [\n",
" {\"id\": ex[\"id\"], \"answers\": ex[\"answers\"]} for ex in raw_validation_dataset\n",
" ]\n",
" \n",
" final_predictions = postprocess_packed_qa_predictions(\n",
" raw_validation_dataset, packed_validation_dataset_unformatted, predictions\n",
" )\n",
"\n",
" formatted_predictions = [\n",
" {\"id\": k, \"prediction_text\": v} for k, v in final_predictions.items()\n",
" ]\n",
"\n",
" metrics = metric.compute(predictions=formatted_predictions, references=target_answers)\n",
" \n",
" return metrics\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "acdde982",
"metadata": {},
"source": [
"### Train and validate the model using the 🤗 Optimum Graphcore `IPUTrainer`\n",
"\n",
"Now let's prepare the model for the IPU. We instantiate the options and machine configuration, then create an `IPUTrainer` to perform training on the IPU in just a few lines.\n",
"\n",
"We need to define an `IPUConfig`, a class that specifies the attributes and configuration parameters used to compile the model and place it on the device. We initialise it with a config name or path, which we set earlier. Then we use it to set the model attribute `model.ipu_config`."
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "e2073fcd",
"metadata": {},
"source": [
"As we are using a pre-trained checkpoint, we can reuse the existing IPU configuration for `Graphcore/bert-base-uncased` for the custom model. No changes should be needed: although the model has been modified to accept a packed dataset, the pipelining stages and IPU options remain the same.\n",
"\n",
"Some options are set explicitly when creating `ipu_config` to make the global batch size visible; they use the configuration values defined at the beginning of this notebook. We can also set inference-specific device iterations and replication factors, which change the global batch size used for validation."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b0452e1",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from optimum.graphcore import IPUConfig, IPUTrainer, IPUTrainingArguments\n",
"\n",
"ipu_config = IPUConfig.from_pretrained(\n",
"    ipu_config_name,\n",
"    executable_cache_dir=executable_cache_dir,\n",
"    gradient_accumulation_steps=gradient_accumulation_steps,\n",
"    device_iterations=device_iterations,\n",
"    replication_factor=1,\n",
"    embedding_serialization_factor=1,\n",
"    # Inference-only settings, used when running validation with trainer.predict()\n",
"    inference_device_iterations=64,\n",
"    inference_replication_factor=1,\n",
")"
]
},
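{
"attachments": {},
"cell_type": "markdown",
"id": "b7d4e6a0",
"metadata": {},
"source": [
"To see how these settings combine into batch sizes, here is a small arithmetic sketch. It assumes the usual PopTorch convention: one weight update consumes micro-batch size × gradient accumulation steps × replication factor samples, `device_iterations` multiplies this again for each host-to-device transfer, and inference combines the evaluation micro-batch size with `inference_device_iterations` and `inference_replication_factor` in the same way. With packing, each sample is a pack containing several sequences, so the number of sequences processed is larger still. The helper functions, the training values and the average number of sequences per pack used below are hypothetical; only the validation values (8, 64 and 1) come from this notebook.\n",
"\n",
"```python\n",
"def training_global_batch_size(micro_batch_size, gradient_accumulation_steps, replication_factor):\n",
"    # Samples (here: packs) that contribute to a single weight update\n",
"    return micro_batch_size * gradient_accumulation_steps * replication_factor\n",
"\n",
"\n",
"def samples_per_step(global_batch_size, device_iterations):\n",
"    # Samples sent to the device per call, covering `device_iterations` weight updates\n",
"    return global_batch_size * device_iterations\n",
"\n",
"\n",
"def inference_batch_size(micro_batch_size, inference_device_iterations, inference_replication_factor):\n",
"    # Samples (packs) processed per inference call\n",
"    return micro_batch_size * inference_device_iterations * inference_replication_factor\n",
"\n",
"\n",
"# Hypothetical training values, for illustration only\n",
"train_gbs = training_global_batch_size(micro_batch_size=2, gradient_accumulation_steps=32, replication_factor=1)\n",
"print(train_gbs)  # 64 packs per weight update\n",
"print(samples_per_step(train_gbs, device_iterations=4))  # 256 packs per call\n",
"avg_sequences_per_pack = 2  # hypothetical; depends on the length distribution of the dataset\n",
"print(train_gbs * avg_sequences_per_pack)  # 128 sequences per weight update\n",
"\n",
"# Validation values from this notebook: per_device_eval_batch_size=8,\n",
"# inference_device_iterations=64, inference_replication_factor=1\n",
"print(inference_batch_size(8, 64, 1))  # 512 packs per call\n",
"```"
]
},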
{
"attachments": {},
"cell_type": "markdown",
"id": "0a6b8635",
"metadata": {},
"source": [
"To instantiate the `IPUTrainer` class, we will need to define `IPUTrainingArguments`, which is a class that contains all the attributes to customize the training. It requires one folder name, which will be used to save the checkpoints of the model. All other arguments are optional:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "141a2e2d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"training_args = IPUTrainingArguments(\n",
" output_dir=f\"./{model_checkpoint}-{model_task}\",\n",
" per_device_train_batch_size=micro_batch_size,\n",
" per_device_eval_batch_size=8,\n",
" num_train_epochs=3,\n",
" learning_rate=9e-05,\n",
" loss_scaling=64.0,\n",
" weight_decay=0.01,\n",
" warmup_ratio=0.25,\n",
" lr_scheduler_type='cosine',\n",
" n_ipu=n_ipu,\n",
" gradient_accumulation_steps=gradient_accumulation_steps,\n",
" dataloader_mode=\"async_rebatched\",\n",
" dataloader_drop_last=True,\n",
" dataloader_num_workers=64,\n",
" logging_steps=5\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "c5e150ed",
"metadata": {},
"source": [
"**Note that we do not run evaluation during training for SQuAD.** Extracting text-level answers for SQuAD requires custom post-processing that needs the tokenized and raw datasets in addition to the logits, whereas the `preprocess_logits_for_metrics` argument of `IPUTrainer` only has access to the logits. Validation is therefore done after training."
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "eeb965d0",
"metadata": {},
"source": [
"We need a data collator to batch our processed samples together. Here we use the default data collator from the [🤗 Transformers library](https://huggingface.co/docs/transformers/index), which is passed to the `IPUTrainer` class.\n",
"\n",
"Then we just need to pass all of this, along with the packed training dataset, to `IPUTrainer`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "561a41ca",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from transformers import default_data_collator\n",
"\n",
"trainer = IPUTrainer(\n",
" model=model,\n",
" ipu_config=ipu_config,\n",
" args=training_args,\n",
" train_dataset=packed_train_dataset,\n",
" data_collator=default_data_collator\n",
")\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "07b0933f",
"metadata": {},
"source": [
"We can now fine-tune our model by just calling the train method:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4cbe563",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"train_run_metrics = trainer.train()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "c31ad4dc",
"metadata": {},
"source": [
"You can now upload the result of the training to the Hugging Face Hub if you successfully authenticated at the beginning of this notebook. Simply uncomment and execute the following cell:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b9e60061",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"#trainer.push_to_hub()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "a6d0fc29",
"metadata": {},
"source": [
"Then save the model to a local directory named after the model checkpoint and task."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "625847dc",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"trainer.save_model(f\"./{model_checkpoint}-{model_task}\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "c93fb854",
"metadata": {},
"source": [
"We can then run validation with the `predict` method of `IPUTrainer`, which returns the raw predictions for the packed validation inputs. By default, it uses the inference-specific batch settings: `per_device_eval_batch_size` from `IPUTrainingArguments` together with `inference_device_iterations` and `inference_replication_factor` from `IPUConfig`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c65f6830",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"raw_predictions = trainer.predict(packed_val_dataset)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "7553b34d",
"metadata": {},
"source": [
"Once the predictions have been obtained, the validation metrics can be computed by passing them to the `compute_validation_metrics` function. As described previously, this post-processes the logits to obtain text answers and then computes the SQuAD accuracy metrics (exact match and F1 score)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "825dd9a8",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"val_metrics = compute_validation_metrics(\n",
" raw_predictions.predictions, raw_validation_dataset, packed_val_dataset, metric)\n",
"\n",
"print(val_metrics)"
]
},
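{
"attachments": {},
"cell_type": "markdown",
"id": "c9e1f5b3",
"metadata": {},
"source": [
"The post-processing step that turns logits into text answers essentially searches for the highest-scoring valid answer span. The sketch below illustrates this idea for a single sequence: each candidate span is scored as start logit plus end logit, restricted to the sequence's context tokens and to a maximum answer length. In a packed input, each sequence occupies its own range of token positions, which is why the context range is passed in explicitly. This is a hypothetical illustration (`best_span` and its `max_answer_length` and `n_best` defaults are assumed), not the implementation of `postprocess_packed_qa_predictions`, and it omits mapping the token span back to characters in the context.\n",
"\n",
"```python\n",
"def best_span(start_logits, end_logits, context_start, context_end, max_answer_length=30, n_best=20):\n",
"    # Keep only the n_best highest-scoring start and end positions\n",
"    starts = sorted(range(len(start_logits)), key=lambda i: start_logits[i], reverse=True)[:n_best]\n",
"    ends = sorted(range(len(end_logits)), key=lambda i: end_logits[i], reverse=True)[:n_best]\n",
"    best, best_score = None, float('-inf')\n",
"    for s in starts:\n",
"        for e in ends:\n",
"            # Skip spans outside this sequence's context, reversed spans and over-long spans\n",
"            if s < context_start or e > context_end or e < s or e - s + 1 > max_answer_length:\n",
"                continue\n",
"            score = start_logits[s] + end_logits[e]\n",
"            if score > best_score:\n",
"                best, best_score = (s, e), score\n",
"    # Returns (None, -inf) if no valid span exists\n",
"    return best, best_score\n",
"\n",
"\n",
"# Hypothetical logits for one pack; tokens 2 to 6 form the context of one sequence\n",
"start_logits = [0.1, 0.2, 3.0, 0.5, 0.1, 2.5, 0.3]\n",
"end_logits = [0.1, 0.1, 0.4, 2.8, 0.2, 0.1, 3.5]\n",
"print(best_span(start_logits, end_logits, context_start=2, context_end=6, max_answer_length=3))  # ((5, 6), 6.0)\n",
"```\n",
"\n",
"Here the span from token 5 to token 6 wins: the strongest start logit, at token 2, can only be paired with weaker end positions within the three-token limit."
]
},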
{
"attachments": {},
"cell_type": "markdown",
"id": "50eb0d90",
"metadata": {},
"source": [
"## Faster batched inference"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "6bda303c",
"metadata": {},
"source": [
"Packing can also be used for inference, particularly for large batched inference workloads. This section demonstrates how to perform faster batched inference on a large number of samples using a simple custom pipeline that batches and packs your input data, performs inference and returns post-processed predictions.\n",
"\n",
"To use the pipeline, we import it and initialise a few essential parameters.\n",
"\n",
"The `model` argument is the model checkpoint; here we use the checkpoint saved locally after fine-tuning on SQuAD. Values for `executable_cache_dir` and `max_seq_length` must also be specified.\n",
"\n",
"The pipeline automatically determines the model's IPU config, provided the checkpoint was trained using Optimum Graphcore, which is the case for the model fine-tuned in this notebook.\n",
"\n",
"In this example, we pre-load the `IPUConfig` and modify some of the default parameters to get the best inference performance and take advantage of IPU parallelism. The micro-batch size can also be specified; the default is 1.\n",
"\n",
"During training, the packing factor affects convergence in the same way as a large increase in batch size would. For inference, however, we are free to use a larger packing factor to speed things up, so let's try `max_seq_per_pack=12`.\n",
"\n",
"**Note:** Packing brings large benefits when performing inference on large amounts of data. For small-scale inference tasks, such as sequential inference on single, un-batched inputs, the generic Optimum Graphcore question-answering pipeline may be preferable. Using a different inference pipeline does not affect fine-tuning, and weights fine-tuned with packing work just the same!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "50cd4b0e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from pipeline.packed_bert import PackedBertQuestionAnsweringPipeline\n",
"\n",
"# Path to the checkpoint saved locally after fine-tuning\n",
"model = f\"./{model_checkpoint}-{model_task}\"\n",
"# model = 'your_username/{model_name}-{model_task}'  # to load from Hugging Face Hub\n",
"\n",
"# Inference-optimised IPU config: 4 single-IPU replicas, each holding all 12 layers\n",
"inference_boosted_ipu_config = IPUConfig.from_pretrained(\n",
"    model,\n",
"    inference_device_iterations=32,\n",
"    inference_replication_factor=4,\n",
"    ipus_per_replica=1,\n",
"    layers_per_ipu=[12],\n",
")\n",
"\n",
"pipeline = PackedBertQuestionAnsweringPipeline(\n",
"    model=model,\n",
"    executable_cache_dir=executable_cache_dir,\n",
"    max_seq_per_pack=12,\n",
"    max_seq_length=max_seq_length,\n",
"    ipu_config=inference_boosted_ipu_config,\n",
"    micro_batch_size=8,\n",
")"
]
},
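{
"attachments": {},
"cell_type": "markdown",
"id": "d2a8b6c4",
"metadata": {},
"source": [
"To build intuition for `max_seq_per_pack`, here is a minimal packing sketch. It uses a simple greedy first-fit-decreasing strategy, which is not necessarily the algorithm used by this notebook's packing utilities: sequences are visited from longest to shortest, and each is placed in the first pack that still has enough free tokens and a free sequence slot. The function name, the sequence lengths and the `max_seq_length` of 384 are hypothetical.\n",
"\n",
"```python\n",
"def pack_sequences(lengths, max_seq_length, max_seq_per_pack):\n",
"    # Greedy first-fit-decreasing: visit sequences from longest to shortest and put each\n",
"    # into the first pack with enough free tokens and a free sequence slot\n",
"    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)\n",
"    packs = []        # sequence indices in each pack\n",
"    used_tokens = []  # tokens already used in each pack\n",
"    for idx in order:\n",
"        length = lengths[idx]\n",
"        if length > max_seq_length:\n",
"            raise ValueError(f'sequence {idx} is longer than max_seq_length')\n",
"        for p, pack in enumerate(packs):\n",
"            if used_tokens[p] + length <= max_seq_length and len(pack) < max_seq_per_pack:\n",
"                pack.append(idx)\n",
"                used_tokens[p] += length\n",
"                break\n",
"        else:\n",
"            # No existing pack fits, so open a new one\n",
"            packs.append([idx])\n",
"            used_tokens.append(length)\n",
"    return packs\n",
"\n",
"\n",
"lengths = [200, 180, 150, 120, 100, 90, 60, 40]  # hypothetical tokenized lengths\n",
"print(pack_sequences(lengths, max_seq_length=384, max_seq_per_pack=12))  # [[0, 1], [2, 3, 4], [5, 6, 7]]\n",
"print(len(pack_sequences(lengths, max_seq_length=384, max_seq_per_pack=2)))  # 4\n",
"```\n",
"\n",
"With at most 2 sequences per pack, the same eight sequences need four packs instead of three, so a higher `max_seq_per_pack` lets each fixed-size input carry more sequences when many of them are short."
]
},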
{
"attachments": {},
"cell_type": "markdown",
"id": "42d720aa",
"metadata": {},
"source": [
"The pipeline expects a **list of strings** directly passed to it in the format:\n",
"```\n",
"questions=[<list of questions>], contexts=[<list of contexts>]\n",
"```\n",
"There is no need to tokenize, preprocess, pack or post-process the data to use the inference pipeline.\n",
"\n",
"As a test, we can load the entire SQuAD validation dataset and perform packed inference using `.predict()` on its question and context columns to generate predictions. Post-processing for SQuAD is done on a sample-by-sample, unbatched basis, so this may take a few minutes with or without packing."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "81969920",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import datasets\n",
"dataset = datasets.load_dataset('squad')\n",
"preds = pipeline.predict(questions=dataset['validation']['question'],\n",
" contexts=dataset['validation']['context'])\n",
"\n",
"print(preds.keys())\n",
"print(f\"Number of predictions: {len(preds['predictions'])}\")\n",
"print(f\"Preprocessing time: {preds['preprocessing_time']}s\")\n",
"print(f\"Postprocessing time: {preds['postprocessing_time']}s\")\n",
"print(f\"Throughput: {preds['throughput']} samples/s\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "3c77e5ac",
"metadata": {
"tags": []
},
"source": [
"There is minimal overhead from tokenizing and packing the dataset, and the speed benefits for inference are evident. Running the above pipeline, we achieve a throughput of approximately 6000 samples per second, an approximate 2x speed-up for SQuAD."
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "aef3a741",
"metadata": {},
"source": [
"## Next steps\n",
"\n",
"You may want to try out the following notebooks using BERT with packing for:\n",
"\n",
"* Single-label text classification - `packedBERT_single_label_text_classification.ipynb`\n",
"* Multi-label text classification - `packedBERT_multi_label_text_classification.ipynb`\n",
"\n",
"Also, check out the full list of [IPU-powered Jupyter Notebooks](https://www.graphcore.ai/ipu-jupyter-notebooks) to get more of a feel for how IPUs perform on other tasks."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"vscode": {
"interpreter": {
"hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}