course/videos/dynamic_padding.ipynb

{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook regroups the code sample of the video below, which is a part of the [Hugging Face course](https://huggingface.co/course)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form" }, "outputs": [ { "data": { "text/html": [ "<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/7q5NyFT8REg?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#@title\n", "from IPython.display import HTML\n", "\n", "HTML('<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/7q5NyFT8REg?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Install the Transformers and Datasets libraries to run this notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "! pip install datasets transformers[sentencepiece]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Reusing dataset glue (/home/sgugger/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n", "Loading cached processed dataset at /home/sgugger/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-2b2682faffe74c3f.arrow\n", "Loading cached processed dataset at /home/sgugger/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-78d79fc323f0156c.arrow\n", "Loading cached processed dataset at /home/sgugger/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-801914374fb3c6ca.arrow\n" ] } ], "source": [ "from datasets import load_dataset\n", "from transformers import AutoTokenizer\n", "\n", "raw_datasets = load_dataset(\"glue\", \"mrpc\")\n", "checkpoint = \"bert-base-cased\"\n", "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n", "\n", "def tokenize_function(examples):\n", " return tokenizer(\n", " examples[\"sentence1\"], examples[\"sentence2\"], padding=\"max_length\", truncation=True, max_length=128\n", " )\n", "\n", "tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)\n", "tokenized_datasets = tokenized_datasets.remove_columns([\"idx\", \"sentence1\", \"sentence2\"])\n", "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n", "tokenized_datasets = tokenized_datasets.with_format(\"torch\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([16, 128])\n", "torch.Size([16, 128])\n", "torch.Size([16, 128])\n", "torch.Size([16, 128])\n", "torch.Size([16, 128])\n", "torch.Size([16, 128])\n", "torch.Size([16, 128])\n" ] } ], "source": [ "from torch.utils.data import DataLoader\n", "\n", "train_dataloader = DataLoader(tokenized_datasets[\"train\"], batch_size=16, shuffle=True)\n", "\n", "for step, batch in enumerate(train_dataloader):\n", " print(batch[\"input_ids\"].shape)\n", " if step > 5:\n", " break" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": 
"stream", "text": [ "Reusing dataset glue (/home/sgugger/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n", "Loading cached processed dataset at /home/sgugger/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-8174fd92eed0af98.arrow\n", "Loading cached processed dataset at /home/sgugger/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-8c99fb059544bc96.arrow\n", "Loading cached processed dataset at /home/sgugger/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-e625eb72bcf1ae1f.arrow\n" ] } ], "source": [ "from datasets import load_dataset\n", "from transformers import AutoTokenizer\n", "\n", "raw_datasets = load_dataset(\"glue\", \"mrpc\")\n", "checkpoint = \"bert-base-cased\"\n", "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n", "\n", "def tokenize_function(examples):\n", " return tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True)\n", "\n", "tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)\n", "tokenized_datasets = tokenized_datasets.remove_columns([\"idx\", \"sentence1\", \"sentence2\"])\n", "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n", "tokenized_datasets = tokenized_datasets.with_format(\"torch\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([16, 83])\n", "torch.Size([16, 75])\n", "torch.Size([16, 81])\n", "torch.Size([16, 75])\n", "torch.Size([16, 80])\n", "torch.Size([16, 81])\n", "torch.Size([16, 81])\n" ] } ], "source": [ "from torch.utils.data import DataLoader\n", "from transformers import DataCollatorWithPadding\n", "\n", "data_collator = DataCollatorWithPadding(tokenizer)\n", "train_dataloader = DataLoader(\n", " tokenized_datasets[\"train\"], batch_size=16, shuffle=True, collate_fn=data_collator\n", ")\n", "\n", "for step, batch in enumerate(train_dataloader):\n", " print(batch[\"input_ids\"].shape)\n", " if step > 5:\n", " break" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "colab": { "name": "What is dynamic padding?", "provenance": [] } }, "nbformat": 4, "nbformat_minor": 4 }