course/videos/semantic_search.ipynb: Text embeddings & semantic search

{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook regroups the code sample of the video below, which is a part of the [Hugging Face course](https://huggingface.co/course)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form" }, "outputs": [ { "data": { "text/html": [ "<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/OATCgQtNX2o?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#@title\n", "from IPython.display import HTML\n", "\n", "HTML('<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/OATCgQtNX2o?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Install the Transformers and Datasets libraries to run this notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "! pip install datasets transformers[sentencepiece]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from transformers import AutoTokenizer, AutoModel\n", "\n", "sentences = [\n", " \"I took my dog for a walk\",\n", " \"Today is going to rain\",\n", " \"I took my cat for a walk\",\n", "]\n", "\n", "model_ckpt = \"sentence-transformers/all-MiniLM-L6-v2\"\n", "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n", "model = AutoModel.from_pretrained(model_ckpt)\n", "\n", "encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors=\"pt\")\n", "\n", "with torch.no_grad():\n", " model_output = model(**encoded_input)\n", " \n", " \n", "token_embeddings = model_output.last_hidden_state\n", "print(f\"Token embeddings shape: {token_embeddings.size()}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch.nn.functional as F\n", "\n", "\n", "def mean_pooling(model_output, attention_mask):\n", " token_embeddings = model_output.last_hidden_state\n", " input_mask_expanded = (\n", " attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n", " )\n", " return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(\n", " input_mask_expanded.sum(1), min=1e-9\n", " )\n", "\n", "\n", "sentence_embeddings = mean_pooling(model_output, encoded_input[\"attention_mask\"])\n", "# Normalize the embeddings\n", "sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)\n", "print(f\"Sentence embeddings shape: {sentence_embeddings.size()}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.metrics.pairwise import cosine_similarity\n", "\n", "sentence_embeddings = sentence_embeddings.detach().numpy()\n", "\n", "scores = np.zeros((sentence_embeddings.shape[0], sentence_embeddings.shape[0]))\n", "\n", "for idx in range(sentence_embeddings.shape[0]):\n", " scores[idx, :] = cosine_similarity([sentence_embeddings[idx]], sentence_embeddings)[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "squad = load_dataset(\"squad\", split=\"validation\").shuffle(seed=42).select(range(100))\n", "\n", "\n", "def get_embeddings(text_list):\n", " encoded_input = tokenizer(\n", " 
```python
from datasets import load_dataset

squad = load_dataset("squad", split="validation").shuffle(seed=42).select(range(100))


def get_embeddings(text_list):
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="pt"
    )
    with torch.no_grad():
        model_output = model(**encoded_input)
    return mean_pooling(model_output, encoded_input["attention_mask"])


# Embed the context passage of each example
squad_with_embeddings = squad.map(
    lambda x: {"embeddings": get_embeddings(x["context"]).cpu().numpy()[0]}
)
```

```python
# Build a FAISS index over the embeddings and retrieve the 3 nearest contexts
squad_with_embeddings.add_faiss_index(column="embeddings")

question = "Who headlined the halftime show for Super Bowl 50?"
question_embedding = get_embeddings([question]).cpu().detach().numpy()

scores, samples = squad_with_embeddings.get_nearest_examples(
    "embeddings", question_embedding, k=3
)
```
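`get_nearest_examples` returns the raw FAISS scores together with a dict mapping each column of the dataset to the values for the retrieved rows. A minimal sketch for inspecting the hits (it assumes the standard SQuAD columns `title` and `context`; whether higher or lower scores mean a better match depends on the metric of the index that was built):

```python
# Pair each retrieved context with its FAISS score and show a short preview.
for score, title, context in zip(scores, samples["title"], samples["context"]):
    print(f"score: {score:.3f} | title: {title}")
    print(context[:200], "...\n")
```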