course/videos/training_loop.ipynb (334 lines of code) (raw):

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook gathers the code samples from the video below, which is part of the [Hugging Face course](https://huggingface.co/course)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/Dh9CL8fyG80?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#@title\n",
    "from IPython.display import HTML\n",
    "\n",
    "HTML('<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/Dh9CL8fyG80?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install the Transformers and Datasets libraries to run this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install datasets transformers[sentencepiece]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reusing dataset glue (/home/sgugger/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n",
      "Loading cached processed dataset at /home/sgugger/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-8174fd92eed0af98.arrow\n",
      "Loading cached processed dataset at /home/sgugger/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-8c99fb059544bc96.arrow\n",
      "Loading cached processed dataset at /home/sgugger/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-e625eb72bcf1ae1f.arrow\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, DataCollatorWithPadding\n",
    "\n",
    "raw_datasets = load_dataset(\"glue\", \"mrpc\")\n",
    "checkpoint = \"bert-base-cased\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
    "\n",
    "def tokenize_function(examples):\n",
    "    return tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True)\n",
    "\n",
    "tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)\n",
    "tokenized_datasets = tokenized_datasets.remove_columns([\"sentence1\", \"sentence2\", \"idx\"])\n",
    "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n",
    "tokenized_datasets.set_format(\"torch\")\n",
    "\n",
    "data_collator = DataCollatorWithPadding(tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "train_dataloader = DataLoader(\n",
    "    tokenized_datasets[\"train\"], shuffle=True, batch_size=8, collate_fn=data_collator\n",
    ")\n",
    "eval_dataloader = DataLoader(\n",
    "    tokenized_datasets[\"validation\"], batch_size=8, collate_fn=data_collator\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'attention_mask': torch.Size([8, 63]), 'input_ids': torch.Size([8, 63]), 'labels': torch.Size([8]), 'token_type_ids': torch.Size([8, 63])}\n"
     ]
    }
   ],
   "source": [
    "for batch in train_dataloader:\n",
    "    break\n",
    "print({k: v.shape for k, v in batch.items()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForSequenceClassification\n",
    "\n",
    "checkpoint = \"bert-base-cased\"\n",
    "model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.7512, grad_fn=<NllLossBackward>) torch.Size([8, 2])\n"
     ]
    }
   ],
   "source": [
    "outputs = model(**batch)\n",
    "print(outputs.loss, outputs.logits.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `transformers.AdamW` is deprecated; use the PyTorch implementation instead.\n",
    "from torch.optim import AdamW\n",
    "\n",
    "optimizer = AdamW(model.parameters(), lr=5e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = outputs.loss\n",
    "loss.backward()\n",
    "optimizer.step()\n",
    "\n",
    "# Don't forget to zero your gradients once your optimizer step is done!\n",
    "optimizer.zero_grad()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import get_scheduler\n",
    "\n",
    "num_epochs = 3\n",
    "num_training_steps = num_epochs * len(train_dataloader)\n",
    "lr_scheduler = get_scheduler(\n",
    "    \"linear\",\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=num_training_steps\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "model.to(device)\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start the real training run with a fresh optimizer: the previous one already\n",
    "# took a demonstration step before the model was moved to `device`.\n",
    "optimizer = AdamW(model.parameters(), lr=5e-5)\n",
    "\n",
    "# A scheduler only adjusts the optimizer it was built with, so rebuild it for the\n",
    "# new optimizer; otherwise the linear decay would never reach the training loop.\n",
    "lr_scheduler = get_scheduler(\n",
    "    \"linear\",\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=num_training_steps,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f6d89d483f35415abe98cd4a5e3ec580",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=1377.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "progress_bar = tqdm(range(num_training_steps))\n",
    "\n",
    "model.train()\n",
    "for epoch in range(num_epochs):\n",
    "    for batch in train_dataloader:\n",
    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "        progress_bar.update(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'accuracy': 0.8284313725490197, 'f1': 0.8809523809523808}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE: `load_metric` is deprecated in favor of the separate `evaluate` library\n",
    "# (`evaluate.load(\"glue\", \"mrpc\")`); switch to it if this import fails.\n",
    "from datasets import load_metric\n",
    "\n",
    "metric = load_metric(\"glue\", \"mrpc\")\n",
    "model.eval()\n",
    "for batch in eval_dataloader:\n",
    "    batch = {k: v.to(device) for k, v in batch.items()}\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**batch)\n",
    "\n",
    "    logits = outputs.logits\n",
    "    predictions = torch.argmax(logits, dim=-1)\n",
    "    metric.add_batch(predictions=predictions, references=batch[\"labels\"])\n",
    "\n",
    "metric.compute()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "name": "Write your training loop in PyTorch",
   "provenance": []
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}