{ "cells": [ { "cell_type": "markdown", "id": "552ec552", "metadata": {}, "source": [ "# SetFit for Text Classification" ] }, { "cell_type": "markdown", "id": "3af7f258-5aaf-47c2-b81e-2f10fc349812", "metadata": {}, "source": [ "In this notebook, we'll learn how to do few-shot text classification with SetFit." ] }, { "cell_type": "markdown", "id": "c5604f73-f395-42cb-8082-9974a87ef9e9", "metadata": { "tags": [] }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "id": "26f09e23-2e1f-41f6-bb40-a30d447a0541", "metadata": {}, "source": [ "If you're running this notebook on Colab or some other cloud platform, you will need to install the `setfit` library. Uncomment the following cell and run it:" ] }, { "cell_type": "code", "execution_count": null, "id": "712b96e8", "metadata": {}, "outputs": [], "source": [ "# %pip install setfit" ] }, { "cell_type": "markdown", "id": "64e4d3b4-93cd-4774-8055-35a00b11f483", "metadata": {}, "source": [ "To be able to share your model with the community, there are a few more steps to follow.\n", "\n", "First, you have to store your authentication token from the Hugging Face Hub (sign up [here](https://huggingface.co/join) if you haven't already!). 
To do so, execute the following cell and input an [access token](https://huggingface.co/docs/hub/security-tokens) associated with your account:" ] }, { "cell_type": "code", "execution_count": null, "id": "526a3b86-db3c-4c27-bb6c-eb39d73326f2", "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "\n", "notebook_login()" ] }, { "cell_type": "markdown", "id": "9309b7d5-1736-46be-b721-eb4ee4ab9e67", "metadata": {}, "source": [ "Then you need to install Git-LFS, which you can do by uncommenting and running the following command:" ] }, { "cell_type": "code", "execution_count": null, "id": "5efb134e-fc40-42f2-b2a5-47112b6f2305", "metadata": {}, "outputs": [], "source": [ "# !apt install git-lfs" ] }, { "cell_type": "markdown", "id": "549981d3", "metadata": {}, "source": [ "Finally, you may need to configure Git on your system by providing details about who you are:" ] }, { "cell_type": "code", "execution_count": 1, "id": "18b950f1", "metadata": {}, "outputs": [], "source": [ "# !git config --global user.email \"you@example.com\"\n", "# !git config --global user.name \"Your Name\"" ] }, { "cell_type": "markdown", "id": "2b2a8fcd-46fe-43e4-835e-57b84964358a", "metadata": {}, "source": [ "This notebook is designed to work with any multiclass [text classification dataset](https://huggingface.co/models?pipeline_tag=text-classification&sort=downloads) and pretrained [Sentence Transformer](https://huggingface.co/models?library=sentence-transformers&sort=downloads) on the Hub. Change the values below to try a different dataset / model!" 
] }, { "cell_type": "code", "execution_count": 2, "id": "41542e15-d211-45e9-b428-e1532c525f5b", "metadata": {}, "outputs": [], "source": [ "dataset_id = \"sst2\"\n", "model_id = \"sentence-transformers/paraphrase-mpnet-base-v2\"" ] }, { "cell_type": "markdown", "id": "3e756be8-3b60-4c86-aa1b-7ef78289b8e2", "metadata": {}, "source": [ "## Loading and sampling the dataset" ] }, { "cell_type": "markdown", "id": "6cac8eff-fc55-4514-aa3e-d4ea4315827e", "metadata": {}, "source": [ "We will use the 🤗 Datasets library to download the data, which can be done as follows:" ] }, { "cell_type": "code", "execution_count": 3, "id": "a478f539-3867-4fc1-94ef-02b6dcc1676f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using custom data configuration default\n", "Reusing dataset sst2 (/home/lewis/.cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "68a64e0c2c5a4bfbbb2d5a5bce46d340", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['idx', 'sentence', 'label'],\n", " num_rows: 67349\n", " })\n", " validation: Dataset({\n", " features: ['idx', 'sentence', 'label'],\n", " num_rows: 872\n", " })\n", " test: Dataset({\n", " features: ['idx', 'sentence', 'label'],\n", " num_rows: 1821\n", " })\n", "})" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from datasets import load_dataset\n", "\n", "dataset = load_dataset(dataset_id)\n", "dataset" ] }, { "cell_type": "markdown", "id": "03aa301d-51ec-4fe5-95c5-8a2e0aa2fb35", "metadata": {}, "source": [ "Most datasets on the Hub have many more labeled examples than those one encounters in few-shot settings. 
To simulate the effect of training on a limited number of examples, let's subsample the training set to have 8 labeled examples per class:" ] }, { "cell_type": "code", "execution_count": 4, "id": "ba671e5e-58f7-4d9e-aa82-8c0413b4a8df", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading cached shuffled indices for dataset at /home/lewis/.cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5/cache-2459d1a782cafb86.arrow\n", "Parameter 'function'=<function sample_dataset.<locals>.<lambda> at 0x7faf0a689630> of the transform datasets.arrow_dataset.Dataset.filter@2.0.1 couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. 
Subsequent hashing failures won't be showed.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8ea017c056a844d2bd4f916c5e05b4fb", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/68 [00:00<?, ?ba/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ddbcf6cf27184655b5d7bc99a2e63191", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/68 [00:00<?, ?ba/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "Dataset({\n", " features: ['idx', 'sentence', 'label'],\n", " num_rows: 16\n", "})" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from setfit import sample_dataset\n", "\n", "train_dataset = sample_dataset(dataset[\"train\"])\n", "train_dataset" ] }, { "cell_type": "markdown", "id": "285cf43e-2064-4da0-8910-a6e621ef2fbd", "metadata": {}, "source": [ "Here we have 16 examples in total to train with (8 per class), since the `sst2` dataset has two classes (positive and negative). For evaluation, we'll use the validation split, since the test split of `sst2` is unlabeled:" ] }, { "cell_type": "code", "execution_count": 5, "id": "c609f71b-76b4-48d1-8024-08d16a713785", "metadata": {}, "outputs": [], "source": [ "eval_dataset = dataset[\"validation\"]" ] }, { "cell_type": "markdown", "id": "059bd547-0e39-43ab-adf0-5f5509217020", "metadata": {}, "source": [ "Okay, now that we have the dataset, let's load and train a model!" ] }, { "cell_type": "markdown", "id": "37e7c839-1f06-4d35-aa34-6e13659db814", "metadata": {}, "source": [ "## Fine-tuning the model" ] }, { "cell_type": "markdown", "id": "78e8c41a", "metadata": {}, "source": [ "To train a SetFit model, the first thing to do is download a pretrained checkpoint from the Hub. 
We can do so by using the `from_pretrained()` method associated with the `SetFitModel` class:" ] }, { "cell_type": "code", "execution_count": 39, "id": "33661c9d-46d3-42eb-9b15-8a2bc49d7f6c", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.\n" ] } ], "source": [ "from setfit import SetFitModel\n", "\n", "model = SetFitModel.from_pretrained(model_id)" ] }, { "cell_type": "markdown", "id": "84e7521e-95ca-431a-8d7f-2f18e1de16ce", "metadata": {}, "source": [ "Here, we've downloaded a pretrained Sentence Transformer from the Hub and added a logistic regression classification head to create the SetFit model. As indicated in the message, we need to train this model on some labeled examples. We can do so by using the `SetFitTrainer` class as follows:" ] }, { "cell_type": "code", "execution_count": 40, "id": "e44b7069-27b0-49ea-bc27-c44f94a98e2f", "metadata": {}, "outputs": [], "source": [ "from sentence_transformers.losses import CosineSimilarityLoss\n", "\n", "from setfit import SetFitTrainer\n", "\n", "trainer = SetFitTrainer(\n", " model=model,\n", " train_dataset=train_dataset,\n", " eval_dataset=eval_dataset,\n", " loss_class=CosineSimilarityLoss,\n", " num_iterations=20,\n", " column_mapping={\"sentence\": \"text\", \"label\": \"label\"},\n", ")" ] }, { "cell_type": "markdown", "id": "cbd3e642", "metadata": {}, "source": [ "The main arguments to notice in the trainer are the following:\n", "\n", "* `loss_class`: The loss function to use for contrastive learning with the Sentence Transformer body\n", "* `num_iterations`: The number of rounds of text-pair generation for contrastive learning. Each round creates one positive and one negative pair per training example, so 16 examples with `num_iterations=20` yield 16 × 20 × 2 = 640 pairs (the `Num examples` reported in the training log below)\n", "* `column_mapping`: The `SetFitTrainer` expects the inputs to be found in `text` and `label` columns. 
This mapping automatically formats the training and evaluation datasets for us." ] }, { "cell_type": "markdown", "id": "b6e3c5ae-c287-4936-b1ac-5eca10c7f39c", "metadata": {}, "source": [ "Now that we've created a trainer, we can train it!" ] }, { "cell_type": "code", "execution_count": 41, "id": "3d79e13b-37b1-4448-a7be-2bcc243b859d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Applying column mapping to training dataset\n", "***** Running training *****\n", " Num examples = 640\n", " Num epochs = 1\n", " Total optimization steps = 40\n", " Total train batch size = 16\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f3c4b2e3949b4908ab8a74cf748c4511", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Epoch: 0%| | 0/1 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3160d74da60244468eda3ca9b0d4393b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Iteration: 0%| | 0/40 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trainer.train()" ] }, { "cell_type": "markdown", "id": "e799f994", "metadata": {}, "source": [ "The final step is to compute the model's performance using the `evaluate()` method:" ] }, { "cell_type": "code", "execution_count": 42, "id": "453c11d0-a1e4-49c2-859a-cc70e033b4a6", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Applying column mapping to evaluation dataset\n", "***** Running evaluation *****\n" ] }, { "data": { "text/plain": [ "{'accuracy': 0.8772935779816514}" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "metrics = trainer.evaluate()\n", "metrics" ] }, { "cell_type": "markdown", "id": "f021b168-02d2-4cc8-942d-65d3b821e253", "metadata": {}, "source": [ "And once the model is trained, you can push it to the Hub:" ] }, { "cell_type": "code", 
"execution_count": null, "id": "c420c4b9-1552-45a5-888c-cdbb78f8e4fc", "metadata": { "scrolled": true }, "outputs": [], "source": [ "trainer.push_to_hub(\"my-awesome-setfit-model\")" ] }, { "cell_type": "markdown", "id": "02173d18-4874-4148-8789-90ac695717bc", "metadata": {}, "source": [ "You can now share this model with your friends, family, and favorite pets: they can all load it with the identifier `your-username/the-name-you-picked`. For instance:" ] }, { "cell_type": "code", "execution_count": 11, "id": "135ba8d2-ac13-4329-946a-04226a253d83", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/lewis/miniconda3/envs/setfit/lib/python3.10/site-packages/sklearn/base.py:329: UserWarning: Trying to unpickle estimator LogisticRegression from version 1.1.1 when using version 1.1.3. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n", "https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations\n", " warnings.warn(\n" ] }, { "data": { "text/plain": [ "array([1, 0])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from setfit import SetFitModel\n", "\n", "model = SetFitModel.from_pretrained(\"lewtun/my-awesome-setfit-model\")\n", "\n", "# Run inference\n", "preds = model([\"i loved the spiderman movie!\", \"pineapple on pizza is the worst 🤮\"])\n", "preds" ] }, { "cell_type": "markdown", "id": "12ae661d-2236-4eb3-8c52-a8fab714beb9", "metadata": {}, "source": [ "## Fine-tuning with a pure PyTorch model\n", "\n", "`setfit` also provides a pure PyTorch implementation of `SetFitModel`, where the head is a dense layer instead of a classifier from `scikit-learn`. 
This allows one to do backprop end-to-end and have more fine-grained control over the training process.\n", "\n", "To use the PyTorch model, we load a pretrained model with `use_differentiable_head=True` and specify the number of classes to include in the head:" ] }, { "cell_type": "code", "execution_count": 48, "id": "58a69d72-edae-45a8-a423-702062ce75a9", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /home/lewis/.cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5/cache-f8c1ea1d1209fce8.arrow\n", "model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.\n" ] } ], "source": [ "from setfit import SetFitModel\n", "\n", "num_classes = len(train_dataset.unique(\"label\"))\n", "model = SetFitModel.from_pretrained(model_id, use_differentiable_head=True, head_params={\"out_features\": num_classes})" ] }, { "cell_type": "markdown", "id": "2959a761-24cc-4d65-a90c-8a12b61d383c", "metadata": {}, "source": [ "As before, we instantiate the trainer:" ] }, { "cell_type": "code", "execution_count": 49, "id": "e3d4038c-4426-45d0-b7a4-d1e89019df7a", "metadata": {}, "outputs": [], "source": [ "trainer = SetFitTrainer(\n", " model=model,\n", " train_dataset=train_dataset,\n", " eval_dataset=eval_dataset,\n", " loss_class=CosineSimilarityLoss,\n", " num_iterations=20,\n", " column_mapping={\"sentence\": \"text\", \"label\": \"label\"},\n", ")" ] }, { "cell_type": "markdown", "id": "8065c187-f69f-4220-912d-036d632de8ac", "metadata": {}, "source": [ "Next, we freeze the weights of the final layer and apply contrastive learning:" ] }, { "cell_type": "code", "execution_count": 50, "id": "362fdd4d-45aa-41e0-a5eb-1b7e9ef69edf", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Applying column 
mapping to training dataset\n", "***** Running training *****\n", " Num examples = 640\n", " Num epochs = 1\n", " Total optimization steps = 40\n", " Total train batch size = 16\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "474d7e253ded4c96b940cdac55b5537f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Epoch: 0%| | 0/1 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "585e7caf1ac54bd5aa8454c7946f6138", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Iteration: 0%| | 0/40 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trainer.freeze()\n", "trainer.train(body_learning_rate=1e-5, num_epochs=1)" ] }, { "cell_type": "markdown", "id": "f970fd13-a19a-4ab7-ba8f-f9dc63c21c37", "metadata": {}, "source": [ "Note that here we can specify the learning rate for the model's body; we find that small values in the 1e-5 range work well for this step.\n", "\n", "Now that the model body is tuned, we can unfreeze the head and train it while keeping the body frozen:" ] }, { "cell_type": "code", "execution_count": 51, "id": "690cae61-e6ff-4f27-9146-4a6d4fd68e0c", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Applying column mapping to training dataset\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3fba50be163e44b8a29dc132b2e75f68", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Epoch: 0%| | 0/50 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trainer.unfreeze(keep_body_frozen=True)\n", "trainer.train(learning_rate=1e-2, num_epochs=50)" ] }, { "cell_type": "markdown", "id": "e8db2267-dcd6-4f3e-a19a-3249b9317741", "metadata": {}, "source": [ "Note that a larger learning rate is used when training the head. We recommend using values in the 1e-2 range. 
Now that the model is trained, we can evaluate it as usual:" ] }, { "cell_type": "code", "execution_count": 52, "id": "08acb60c-a1ae-4d74-ad96-c3e2a9bdf6ac", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Applying column mapping to evaluation dataset\n", "***** Running evaluation *****\n" ] }, { "data": { "text/plain": [ "{'accuracy': 0.8577981651376146}" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.evaluate()" ] }, { "cell_type": "markdown", "id": "ba50c71c-923a-4f63-82ab-d705182dcc0b", "metadata": {}, "source": [ "Nice! An accuracy of about 85.8% is comparable to the roughly 87.7% we obtained with the `scikit-learn` head." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" }, "vscode": { "interpreter": { "hash": "1a53731e204626af339a5238c341a3f8c4bfd7cb5ccdda48ca3fe8366eef4175" } } }, "nbformat": 4, "nbformat_minor": 5 }