notebooks/91_tf_xla_generate.ipynb

{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Faster Text Generation with TensorFlow and XLA\n", "\n", "This notebook is a companion to the 🤗 [blog post with the same title](https://huggingface.co/blog/tf-xla-generate). \n", "It is meant to illustrate how to use XLA with TensorFlow text generation.\n", "\n", "It contains two stand-alone examples, one for encoder-decoder models and another for decoder-only models.\n", "\n", "⚠️ If you are running this on colab, you might not have access to a GPU. The benefits of XLA are best observed with a GPU!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Preparing the environment\n", "!pip install transformers>=4.21.0" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Stand-alone TF XLA generate example for Encoder-Decoder Models.\n", "\n", "# Note: execution times are deeply dependent on hardware.\n", "# If you have a machine with a powerful GPU, I highly recommend you to try this example there!\n", "import time\n", "import tensorflow as tf\n", "from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM\n", "\n", "# 1. Load model and tokenizer\n", "model_name = \"t5-small\"\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)\n", "\n", "# 2. Prepare tokenization and generation arguments -- don't forget padding to avoid retracing!\n", "tokenization_kwargs = {\"pad_to_multiple_of\": 32, \"padding\": True, \"return_tensors\": \"tf\"}\n", "generation_kwargs = {\"num_beams\": 4, \"max_new_tokens\": 32}\n", "\n", "# 3. Create your XLA generate function a̶n̶d̶ ̶m̶a̶k̶e̶ ̶P̶y̶T̶o̶r̶c̶h̶ ̶e̶a̶t̶ ̶d̶u̶s̶t̶\n", "# This is the only change with respect to original generate workflow!\n", "xla_generate = tf.function(model.generate, jit_compile=True)\n", "\n", "# 4. Generate! Remember -- the first call will be slow, but all subsequent calls will be fast if you've done things right.\n", "input_prompts = [\n", " f\"translate English to {language}: I have four cats and three dogs.\" for language in [\"German\", \"French\", \"Romanian\"]\n", "]\n", "for input_prompt in input_prompts:\n", " tokenized_inputs = tokenizer([input_prompt], **tokenization_kwargs)\n", " start = time.time_ns()\n", " generated_text = xla_generate(**tokenized_inputs, **generation_kwargs)\n", " end = time.time_ns()\n", " decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)\n", " print(f\"Original prompt -- {input_prompt}\")\n", " print(f\"Generated -- {decoded_text}\")\n", " print(f\"Execution time -- {(end - start) / 1e6:.1f} ms\\n\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Stand-alone TF XLA generate example for Decoder-Only Models.\n", "\n", "# Note: execution times are deeply dependent on hardware.\n", "# If you have a machine with a powerful GPU, I highly recommend you to try this example there!\n", "import time\n", "import tensorflow as tf\n", "from transformers import AutoTokenizer, TFAutoModelForCausalLM\n", "\n", "# 1. Load model and tokenizer\n", "model_name = \"gpt2\"\n", "# remember: decoder-only models need left-padding\n", "tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side=\"left\", pad_token=\"</s>\")\n", "model = TFAutoModelForCausalLM.from_pretrained(model_name)\n", "\n", "# 2. 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Stand-alone TF XLA generate example for Decoder-Only Models.\n", "\n", "# Note: execution times are highly dependent on hardware.\n", "# If you have a machine with a powerful GPU, I highly recommend trying this example there!\n", "import time\n", "import tensorflow as tf\n", "from transformers import AutoTokenizer, TFAutoModelForCausalLM\n", "\n", "# 1. Load model and tokenizer\n", "model_name = \"gpt2\"\n", "# Remember: decoder-only models need left-padding\n", "tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side=\"left\", pad_token=\"</s>\")\n", "model = TFAutoModelForCausalLM.from_pretrained(model_name)\n", "\n", "# 2. Prepare tokenization and generation arguments -- don't forget padding to avoid retracing!\n", "tokenization_kwargs = {\"pad_to_multiple_of\": 32, \"padding\": True, \"return_tensors\": \"tf\"}\n", "generation_kwargs = {\"num_beams\": 4, \"max_new_tokens\": 32}\n", "\n", "# 3. Create your XLA generate function a̶n̶d̶ ̶m̶a̶k̶e̶ ̶P̶y̶T̶o̶r̶c̶h̶ ̶e̶a̶t̶ ̶d̶u̶s̶t̶\n", "# This is the only change with respect to the original generate workflow!\n", "xla_generate = tf.function(model.generate, jit_compile=True)\n", "\n", "# 4. Generate! Remember -- the first call will be slow, but all subsequent calls will be fast if you've done things right.\n", "input_prompts = [f\"The best thing about {country} is\" for country in [\"Spain\", \"Japan\", \"Angola\"]]\n", "for input_prompt in input_prompts:\n", "    tokenized_inputs = tokenizer([input_prompt], **tokenization_kwargs)\n", "    start = time.time_ns()\n", "    generated_text = xla_generate(**tokenized_inputs, **generation_kwargs)\n", "    end = time.time_ns()\n", "    decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)\n", "    print(f\"Original prompt -- {input_prompt}\")\n", "    print(f\"Generated -- {decoded_text}\")\n", "    print(f\"Execution time -- {(end - start) / 1e6:.1f} ms\\n\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.13 ('hf': venv)", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.8.13" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "84f3c8774ca1c74eb574ae1655a273850a12d5dbb694801a64998ecbefff8fe7" } } }, "nbformat": 4, "nbformat_minor": 2 }