notebooks/53_constrained_beam_search.ipynb (513 lines of code) (raw):

{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "53_constrained_beam_search.ipynb", "provenance": [], "collapsed_sections": [], "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "<a href=\"https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/53_constrained_beam_search.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" ] }, { "cell_type": "markdown", "metadata": { "id": "Vp3XPuaTu9jl" }, "source": [ "\n", "# Guiding Text Generation with Constrained Beam Search in 🤗 Transformers" ] }, { "cell_type": "markdown", "metadata": { "id": "KxLvv6UaPa33" }, "source": [ "## **Introduction**\n", "\n", "This blog post assumes that the reader is familiar with text generation methods using the different variants of beam search, as explained in the blog post: [\"How to generate text: using different decoding methods for language generation with Transformers\"](https://huggingface.co/blog/how-to-generate)\n", "\n", "Unlike ordinary beam search, **constrained** beam search allows us to exert control over the output of text generation. This is useful because, sometimes, we know exactly what we want inside the output. For example, in a Neural Machine Translation task, we might know from a dictionary lookup which words must be included in the final translation. Sometimes, generation outputs that are almost equally probable under a language model might not be equally desirable for the end-user due to the particular context. Both of these situations could be solved by allowing the users to tell the model which words must be included in the end output. \n", "\n", "### **Why It's Difficult**\n", "\n", "However, this is a non-trivial problem. 
This is because the task requires us to force the generation of certain subsequences *somewhere* in the final output, at *some point* during the generation. \n", "\n", "Let's say that we want to generate a sentence $S$ that has to include the phrase $p_1=\\{ t_1, t_2 \\}$ with tokens $t_1, t_2$ in order. Let's define the expected sentence $S$ as:\n", "\n", "$S_{expected} = \\{ s_1, s_2, ..., s_k, t_1, t_2, s_{k+1}, ..., s_n \\}$ \n", "\n", "The problem is that beam search generates the sequence *token-by-token*. Though not entirely accurate, one can think of beam search as the function $B(\\mathbf{s}_{0:i}) = s_{i+1}$, where it looks at the currently generated sequence of tokens from $0$ to $i$ then predicts the next token at $i+1$. But how can this function know, at an arbitrary step $i < k$, that the tokens must be generated at some future step $k$? Or when it's at the step $i=k$, how can it know for sure that this is the best spot to force the tokens, instead of some future step $i>k$?\n", "\n", "![Why constraints are hard](https://raw.githubusercontent.com/cwkeam/scientific-images/main/why_constraints_are_hard.png)\n", "\n", "\n", "And what if you have multiple constraints with varying requirements? What if you want to force the phrase $p_1=\\{t_1, t_2\\}$ *and* also the phrase $p_2=\\{ t_3, t_4, t_5, t_6\\}$? What if you want the model to **choose between** the two phrases? What if we want to force the phrase $p_1$ and force just one phrase among the list of phrases $\\{p_{21}, p_{22}, p_{23}\\}$? \n", "\n", "The above examples are actually very reasonable use-cases, as will be shown below, and the new constrained beam search feature allows for all of them!\n", "\n", "This post will quickly go over what the new ***constrained beam search*** feature can do for you, and then go into deeper details about how it works under the hood." 
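, "\n", "\n", "As a rough formalization of the mixed requirement described above (the $\\sqsubseteq$ notation and the sets $D_j$ are our own shorthand for this post, not library terminology), write $p \\sqsubseteq S$ when the phrase $p$ appears in $S$ as a contiguous run of tokens. Forcing $p_1$ together with at least one phrase among $\\{p_{21}, p_{22}, p_{23}\\}$ then asks for an output $S$ such that:\n", "\n", "$(p_1 \\sqsubseteq S) \\wedge (p_{21} \\sqsubseteq S \\vee p_{22} \\sqsubseteq S \\vee p_{23} \\sqsubseteq S)$\n", "\n", "More generally, a set of constraints can be read as $\\bigwedge_{j} \\bigvee_{p \\in D_j} (p \\sqsubseteq S)$, where each $D_j$ is a set of alternative phrases and a plain forced phrase is simply the case where $D_j$ contains a single phrase."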
] }, { "cell_type": "markdown", "source": [ "## **Example 1: Forcing a Word**\n", "\n", "Let's say we're trying to translate `\"How old are you?\"` to German. \n", "\n", "`\"Wie alt bist du?\"` is what you'd say in an informal setting, and `\"Wie alt sind Sie?\"` is what \n", "you'd say in a formal setting.\n", "\n", "And depending on the context, we might want one form of formality over the other, but how do we tell the model that?\n", "\n", "### **Traditional Beam Search**\n", "\n", "Here's how we would do text translation in the ***traditional beam search setting.***\n" ], "metadata": { "id": "5-16ivOm9RNJ" } }, { "cell_type": "code", "source": [ "!pip install -q git+https://github.com/huggingface/transformers.git" ], "metadata": { "id": "cZv8Gwja8PqJ" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"t5-base\")\n", "model = AutoModelForSeq2SeqLM.from_pretrained(\"t5-base\")\n", "\n", "encoder_input_str = \"translate English to German: How old are you?\"\n", "\n", "input_ids = tokenizer(encoder_input_str, return_tensors=\"pt\").input_ids\n", "\n", "outputs = model.generate(\n", " input_ids,\n", " num_beams=10,\n", " num_return_sequences=1,\n", " no_repeat_ngram_size=1,\n", " remove_invalid_values=True,\n", ")\n", "\n", "print(\"Output:\\n\" + 100 * '-')\n", "print(tokenizer.decode(outputs[0], skip_special_tokens=True))" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Dt7R0ZIh9hCn", "outputId": "e8159e95-7b03-4a31-9067-04022e5eb0c2" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Output:\n", "----------------------------------------------------------------------------------------------------\n", "Wie alt bist du?\n" ] } ] }, { "cell_type": "markdown", "metadata": { "id": "Si4GyYhOQMzi" }, "source": [ "\n", "### **With Constrained Beam 
Search**\n", "\n", "But what if we knew that we wanted a formal output instead of the informal one? What if we were able to know from prior knowledge what the generation must include and we were able to *inject it* into the generation?\n", "\n", "The following is what is possible now with the `force_words_ids` keyword argument to `model.generate()`:" ] }, { "cell_type": "code", "metadata": { "id": "XbzZ_IVTtoQe", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "12b255ae-04c4-4cf1-ec40-042bd5effad3" }, "source": [ "tokenizer = AutoTokenizer.from_pretrained(\"t5-base\")\n", "model = AutoModelForSeq2SeqLM.from_pretrained(\"t5-base\")\n", "\n", "encoder_input_str = \"translate English to German: How old are you?\"\n", "\n", "force_words = [\"Sie\"]\n", "\n", "input_ids = tokenizer(encoder_input_str, return_tensors=\"pt\").input_ids\n", "force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids\n", "\n", "outputs = model.generate(\n", " input_ids,\n", " force_words_ids=force_words_ids,\n", " num_beams=5,\n", " num_return_sequences=1,\n", " no_repeat_ngram_size=1,\n", " remove_invalid_values=True,\n", ")\n", "\n", "\n", "print(\"Output:\\n\" + 100 * '-')\n", "print(tokenizer.decode(outputs[0], skip_special_tokens=True))" ], "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Output:\n", "----------------------------------------------------------------------------------------------------\n", "Wie alt sind Sie?\n" ] } ] }, { "cell_type": "markdown", "source": [ "As you can see, we were able to guide the generation with prior knowledge about our desired output. Previously we would've had to generate a bunch of possible outputs, then filter the ones that fit our requirement. Now we can do that at the generation stage." 
], "metadata": { "id": "B7bJ8oUO0m17" } }, { "cell_type": "markdown", "source": [ "## **Example 2: Disjunctive Constraints**\n", "\n", "We mentioned above a use-case where we know which words we want included in the final output. An example of this might be using a dictionary lookup during neural machine translation.\n", "\n", "But what if we don't know which *word forms* to use, where we'd want outputs like `[\"raining\", \"rained\", \"rains\", ...]` to be equally possible? In a more general sense, there are always cases when we don't want the *exact word verbatim*, letter by letter, and might be open to other related possibilities too.\n", "\n", "Constraints that allow for this behavior are ***Disjunctive Constraints***, which allow the user to input a list of words, whose purpose is to guide the generation such that the final output must contain *at least one* among the list of words. \n", "\n", "Here's an example that uses a mix of the above two types of constraints: " ], "metadata": { "id": "aYETQ8NEg_6e" } }, { "cell_type": "code", "metadata": { "id": "ue2kOQhXTAMU", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "14b65681-3720-4e48-ed50-85598c12c608" }, "source": [ "from transformers import GPT2LMHeadModel, GPT2Tokenizer\n", "\n", "model = GPT2LMHeadModel.from_pretrained(\"gpt2\")\n", "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n", "\n", "force_word = \"scared\"\n", "force_flexible = [\"scream\", \"screams\", \"screaming\", \"screamed\"]\n", "\n", "force_words_ids = [\n", " tokenizer([force_word], add_prefix_space=True, add_special_tokens=False).input_ids,\n", " tokenizer(force_flexible, add_prefix_space=True, add_special_tokens=False).input_ids,\n", "]\n", "\n", "starting_text = [\"The soldiers\", \"The child\"]\n", "\n", "input_ids = tokenizer(starting_text, return_tensors=\"pt\").input_ids\n", "\n", "outputs = model.generate(\n", " input_ids,\n", " force_words_ids=force_words_ids,\n", " num_beams=10,\n", " 
num_return_sequences=1,\n", " no_repeat_ngram_size=1,\n", " remove_invalid_values=True,\n", ")\n", "\n", "\n", "print(\"Output:\\n\" + 100 * '-')\n", "print(tokenizer.decode(outputs[0], skip_special_tokens=True))\n", "print(tokenizer.decode(outputs[1], skip_special_tokens=True))\n" ], "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Output:\n", "----------------------------------------------------------------------------------------------------\n", "The soldiers, who were all scared and screaming at each other as they tried to get out of the\n", "The child was taken to a local hospital where she screamed and scared for her life, police said.\n" ] } ] }, { "cell_type": "markdown", "source": [ "As you can see, the first output used `\"screaming\"`, the second output used `\"screamed\"`, and both used `\"scared\"` verbatim. The list to choose from `[\"screaming\", \"screamed\", ...]` doesn't have to be word forms; this can satisfy any use-case where we need just one from a list of words." ], "metadata": { "id": "6lulCgAHjqYA" } }, { "cell_type": "markdown", "metadata": { "id": "g8DnXZ1WiuNd" }, "source": [ "## **Traditional Beam search**\n", "\n", "The following is an example of traditional **beam search**, taken from a previous [blog post](https://huggingface.co/blog/how-to-generate):\n", "\n", "\n", "![Beam search](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/beam_search.png)\n", "\n", "Unlike greedy search, beam search works by keeping a longer list of hypotheses. 
In the above picture, we have displayed three possible next tokens at each possible step in the generation.\n", "\n", "Here's another way to look at the first step of the beam search for the above example, in the case of `num_beams=3`:\n", "\n", "![Beam search step 1](https://raw.githubusercontent.com/cwkeam/scientific-images/main/beam_1.jpg)\n", "\n", "Instead of only choosing `\"The dog\"` like what a greedy search would do, a beam search would allow *further consideration* of `\"The nice\"` and `\"The car\"`.\n", "\n", "In the next step, we consider the next possible tokens for each of the three branches we created in the previous step.\n", "\n", "![Beam search step 2](https://raw.githubusercontent.com/cwkeam/scientific-images/main/beam_2.jpg)\n", "\n", "Though we end up *considering* significantly more than `num_beams` outputs, we reduce them down to `num_beams` at the end of the step. We can't just keep branching out, because then the number of beams we'd have to keep track of would be $beams^{n}$ for $n$ steps, which becomes very large very quickly ($10$ beams after $10$ steps is $10,000,000,000$ beams!). \n", "\n", "For the rest of the generation, we repeat the above step until an ending criterion has been met, like generating the `<eos>` token or reaching `max_length`, for example. Branch out, rank, reduce, and repeat.\n", "\n", "\n", "\n", "\n", "\n" ] }, { "cell_type": "markdown", "source": [ "\n", "## **Constrained Beam Search**\n", "\n", "Constrained beam search attempts to fulfill the constraints by *injecting* the desired tokens at every step of the generation. \n", "\n", "Let's say that we're trying to force the phrase `\"is fast\"` in the generated output. \n", "\n", "In the traditional beam search setting, we find the top `k` most probable next tokens at each branch and append them for consideration. In the constrained setting we actually do the same, but also append the tokens that will take us *closer to fulfilling our constraints*. 
Here's a demonstration:\n", "\n", "\n", "![Constrained Beam Search Step 1](https://raw.githubusercontent.com/cwkeam/scientific-images/main/cbeam_1.jpg)\n", "\n", "On top of the usual high-probability next tokens like `\"dog\"` and `\"nice\"`, we force the token `\"is\"` in order to get us closer to fulfilling our constraint of `\"is fast\"`.\n", "\n", "For the next step, the branched-out candidates below are mostly the same as that of traditional beam search. But like the above example, constrained beam search adds onto the existing candidates by forcing the constraints at each new branch:\n", "\n", "![Constrained Beam Search Step 2](https://raw.githubusercontent.com/cwkeam/scientific-images/main/cbeam_2.jpg)\n", "\n", "### **Banks**\n", "\n", "Before we talk about the next step, we need to think about the resulting undesirable behavior we can see in the above step. \n", "\n", "The problem with naively just forcing the desired phrase `\"is fast\"` in the output is that, most of the time, you'd end up with nonsensical outputs like `\"The is fast\"` above. This is actually what makes this a nontrivial problem to solve. A deeper discussion about the complexities of solving this problem can be found in the [original feature request issue](https://github.com/huggingface/transformers/issues/14081#issuecomment-1004479944) that was raised in `huggingface/transformers`.\n", "\n", "Banks solve this problem by creating a *balance* between fulfilling the constraints and creating sensible output. \n", "\n", "Bank $n$ refers to the ***list of beams that have made $n$ steps progress in fulfilling the constraints***. After sorting all the possible beams into their respective banks, we do a round-robin selection. With the above example, we'd select the most probable output from Bank 2, then most probable from Bank 1, one from Bank 0, the second most probable from Bank 2, the second most probable from Bank 1, and so forth. 
Since we're using `num_beams=3`, we just do the above process three times to end up with `[\"The is fast\", \"The dog is\", \"The dog and\"]`.\n", "\n", "This way, even though we're *forcing* the model to consider the branch where we've manually appended the desired token, we still keep track of other high-probability sequences that probably make more sense. Even though `\"The is fast\"` fulfills our constraint completely, it's not a very sensible phrase. Luckily, we have `\"The dog is\"` and `\"The dog and\"` to work with in future steps, which hopefully will lead to more sensible outputs later on.\n", "\n", "This behavior is demonstrated in the third step of the above example:\n", "\n", "![Constrained Beam Search Step 3](https://raw.githubusercontent.com/cwkeam/scientific-images/main/cbeam_3.jpg)\n", "\n", "Notice how `\"The is fast\"` doesn't require any manual appending of constraint tokens since it's already fulfilled (i.e. already contains the phrase `\"is fast\"`). Also notice how beams like `\"The dog is slow\"` or `\"The dog is mad\"` are actually in Bank 0, since, although they include the token `\"is\"`, they must restart from the beginning in order to generate `\"is fast\"`. By appending something like `\"slow\"` after `\"is\"`, such a beam has effectively *reset its progress*. \n", "\n", "And finally notice how we ended up at a sensible output that contains our constraint phrase: `\"The dog is fast\"`! \n", "\n", "We were worried in the beginning because just blindly appending the desired tokens led to nonsensical phrases like `\"The is fast\"`. However, with the use of round-robin selection from banks, we implicitly ended up getting rid of nonsensical outputs in favor of the more sensible outputs. " ], "metadata": { "id": "eIaZfWiN390B" } }, { "cell_type": "markdown", "source": [ "## **More About `Constraint` Classes and Custom Constraints**\n", "\n", "The main takeaway from the explanation can be summarized as the following. 
At every step, we keep pestering the model to consider the tokens that fulfill our constraints, all the while keeping track of beams that don't, until we end up with reasonably high probability sequences that contain our desired phrases. \n", "\n", "So a principled way to design this implementation was to represent each constraint as a `Constraint` object, whose purpose was to keep track of its progress and tell the beam search which tokens to generate next. Although we have provided the keyword argument `force_words_ids` for `model.generate()`, the following is what actually happens in the back-end:\n", "\n" ], "metadata": { "id": "Lz11vuE-Aiqp" } }, { "cell_type": "code", "source": [ "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PhrasalConstraint\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"t5-base\")\n", "model = AutoModelForSeq2SeqLM.from_pretrained(\"t5-base\")\n", "\n", "encoder_input_str = \"translate English to German: How old are you?\"\n", "\n", "constraints = [\n", " PhrasalConstraint(\n", " tokenizer(\"Sie\", add_special_tokens=False).input_ids\n", " )\n", "]\n", "\n", "input_ids = tokenizer(encoder_input_str, return_tensors=\"pt\").input_ids\n", "\n", "\n", "outputs = model.generate(\n", " input_ids,\n", " constraints=constraints,\n", " num_beams=10,\n", " num_return_sequences=1,\n", " no_repeat_ngram_size=1,\n", " remove_invalid_values=True,\n", ")\n", "\n", "\n", "print(\"Output:\\n\" + 100 * '-')\n", "print(tokenizer.decode(outputs[0], skip_special_tokens=True))" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-OrgGyar-rTy", "outputId": "eb5f4440-07b7-4ab8-b2fc-534c52042ec5" }, "execution_count": 3, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Output:\n", "----------------------------------------------------------------------------------------------------\n", "Wie alt sind Sie?\n" ] } ] }, { "cell_type": "markdown", "source": [ "In order to design your own unique 
constraints, you can define one yourself and pass it via the `constraints` keyword argument. You just have to create a sub-class of the `Constraint` abstract interface class and follow its requirements. More information about this can be found in the definition of `Constraint` found [here](https://github.com/huggingface/transformers/blob/master/src/transformers/generation_beam_constraints.py).\n", "\n", "Some unique ideas (not yet implemented; maybe you can give it a try!) include constraints like `OrderedConstraints` and `TemplateConstraints`, which may be added further down the line. Currently, a constraint is fulfilled as long as its sequence appears anywhere in the output. For example, a previous example had one sequence with scared -> screaming and the other with screamed -> scared. `OrderedConstraints` could allow the user to specify the order in which these constraints are fulfilled. \n", "\n", "`TemplateConstraints` could allow for a more niche use of the feature, where the objective can be something like:\n", "\n", "```python\n", "starting_text = \"The woman\"\n", "template = [\"the\", \"\", \"School of\", \"\", \"in\"]\n", "\n", "possible_outputs = [\n", " \"The woman attended the Ross School of Business in Michigan.\",\n", " \"The woman was the administrator for the Harvard School of Business in MA.\"\n", "]\n", "```\n", "\n", "or:\n", "```python\n", "starting_text = \"The woman\"\n", "template = [\"the\", \"\", \"\", \"University\", \"\", \"in\"]\n", "\n", "possible_outputs = [\n", " \"The woman attended the Carnegie Mellon University in Pittsburgh.\",\n", "]\n", "impossible_outputs = [\n", " \"The woman attended the Harvard University in MA.\"\n", "]\n", "``` \n", "\n", "or if the user does not care about the number of tokens that can go in between two words, then one can just use `OrderedConstraints`.\n" ], "metadata": { "id": "cMbHnbcBDCjz" } }, { "cell_type": "markdown", "source": [ "## **Conclusion**\n", "\n", "Constrained beam search gives us a 
flexible means to inject external knowledge and requirements into text generation. Previously, there was no easy way to tell the model to (1) include a list of sequences, (2) some of which are optional and some of which are not, such that (3) they're generated *somewhere* in the sequence at reasonable respective positions. Now, with a mix of different subclasses of `Constraint` objects, we can have much finer control over our generation! \n", "\n", "This new feature is based mainly on the following papers:\n", "\n", " - [Guided Open Vocabulary Image Captioning with Constrained Beam Search](https://arxiv.org/abs/1612.00576)\n", " - [Fast Lexically Constrained Decoding with Dynamic Beam Allocation for Neural Machine Translation](https://arxiv.org/abs/1804.06609)\n", " - [Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting](https://aclanthology.org/N19-1090/)\n", " - [Guided Generation of Cause and Effect](https://arxiv.org/pdf/2107.09846.pdf)\n", "\n", "Many recent research papers, like the ones above, are exploring ways of using external knowledge (e.g. KGs, KBs) to guide the outputs of large deep learning models. Hopefully this constrained beam search feature becomes another effective way to achieve this purpose.\n", "\n", "Thanks to everybody who gave guidance for this feature contribution: Patrick von Platen for being involved from the [initial issue](https://github.com/huggingface/transformers/issues/14081) to the [final PR](https://github.com/huggingface/transformers/pull/15761), and Narsil Patry, for providing detailed feedback on the code.\n", "\n", "*Thumbnail of this post uses an icon with the attribution: <a href=\"https://www.flaticon.com/free-icons/shorthand\" title=\"shorthand icons\">Shorthand icons created by Freepik - Flaticon</a>*\n", "\n" ], "metadata": { "id": "CZ2GyB2NFiSm" } } ] }