{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Faster Text Generation with TensorFlow and XLA\n",
"\n",
"This notebook is a companion to the 🤗 [blog post with the same title](https://huggingface.co/blog/tf-xla-generate). \n",
"It is meant to illustrate how to use XLA with TensorFlow text generation.\n",
"\n",
"It contains two stand-alone examples, one for encoder-decoder models and another for decoder-only models.\n",
"\n",
"⚠️ If you are running this on colab, you might not have access to a GPU. The benefits of XLA are best observed with a GPU!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Preparing the environment\n",
"!pip install transformers>=4.21.0"
]
},
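{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: list the GPUs visible to TensorFlow.\n",
"# If this prints an empty list, the examples below will still run, but the XLA speedups will be much less dramatic.\n",
"import tensorflow as tf\n",
"\n",
"print(\"GPUs visible to TensorFlow:\", tf.config.list_physical_devices(\"GPU\"))"
]
},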
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Stand-alone TF XLA generate example for Encoder-Decoder Models.\n",
"\n",
"# Note: execution times are deeply dependent on hardware.\n",
"# If you have a machine with a powerful GPU, I highly recommend you to try this example there!\n",
"import time\n",
"import tensorflow as tf\n",
"from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM\n",
"\n",
"# 1. Load model and tokenizer\n",
"model_name = \"t5-small\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
"\n",
"# 2. Prepare tokenization and generation arguments -- don't forget padding to avoid retracing!\n",
"tokenization_kwargs = {\"pad_to_multiple_of\": 32, \"padding\": True, \"return_tensors\": \"tf\"}\n",
"generation_kwargs = {\"num_beams\": 4, \"max_new_tokens\": 32}\n",
"\n",
"# 3. Create your XLA generate function a̶n̶d̶ ̶m̶a̶k̶e̶ ̶P̶y̶T̶o̶r̶c̶h̶ ̶e̶a̶t̶ ̶d̶u̶s̶t̶\n",
"# This is the only change with respect to original generate workflow!\n",
"xla_generate = tf.function(model.generate, jit_compile=True)\n",
"\n",
"# 4. Generate! Remember -- the first call will be slow, but all subsequent calls will be fast if you've done things right.\n",
"input_prompts = [\n",
" f\"translate English to {language}: I have four cats and three dogs.\" for language in [\"German\", \"French\", \"Romanian\"]\n",
"]\n",
"for input_prompt in input_prompts:\n",
" tokenized_inputs = tokenizer([input_prompt], **tokenization_kwargs)\n",
" start = time.time_ns()\n",
" generated_text = xla_generate(**tokenized_inputs, **generation_kwargs)\n",
" end = time.time_ns()\n",
" decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)\n",
" print(f\"Original prompt -- {input_prompt}\")\n",
" print(f\"Generated -- {decoded_text}\")\n",
" print(f\"Execution time -- {(end - start) / 1e6:.1f} ms\\n\")"
]
},
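{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional extra: a minimal sketch of *why* the padding arguments matter.\n",
"# XLA compiles one program per input shape, so prompts of different lengths force a fresh (slow) trace\n",
"# unless they are padded to a shared, bucketed length. Exact timings depend on your hardware.\n",
"import time\n",
"import tensorflow as tf\n",
"from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"t5-small\")\n",
"model = TFAutoModelForSeq2SeqLM.from_pretrained(\"t5-small\")\n",
"xla_generate = tf.function(model.generate, jit_compile=True)\n",
"\n",
"prompts = [\"translate English to German: hi\", \"translate English to German: hi there, how are you?\"]\n",
"\n",
"def time_generate(tokenization_kwargs):\n",
"    timings = []\n",
"    for prompt in prompts:\n",
"        inputs = tokenizer([prompt], **tokenization_kwargs)\n",
"        start = time.time_ns()\n",
"        xla_generate(**inputs, max_new_tokens=8)\n",
"        timings.append(f\"{(time.time_ns() - start) / 1e6:.1f} ms\")\n",
"    return timings\n",
"\n",
"# Different input lengths -> different shapes -> both calls trace and compile (slow).\n",
"print(\"Without padding:\", time_generate({\"return_tensors\": \"tf\"}))\n",
"# Padded to a multiple of 32 -> both prompts share a shape -> the second call reuses the compiled trace.\n",
"print(\"With padding:   \", time_generate({\"pad_to_multiple_of\": 32, \"padding\": True, \"return_tensors\": \"tf\"}))"
]
},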
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Stand-alone TF XLA generate example for Decoder-Only Models.\n",
"\n",
"# Note: execution times are deeply dependent on hardware.\n",
"# If you have a machine with a powerful GPU, I highly recommend you to try this example there!\n",
"import time\n",
"import tensorflow as tf\n",
"from transformers import AutoTokenizer, TFAutoModelForCausalLM\n",
"\n",
"# 1. Load model and tokenizer\n",
"model_name = \"gpt2\"\n",
"# remember: decoder-only models need left-padding\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side=\"left\", pad_token=\"</s>\")\n",
"model = TFAutoModelForCausalLM.from_pretrained(model_name)\n",
"\n",
"# 2. Prepare tokenization and generation arguments -- don't forget padding to avoid retracing!\n",
"tokenization_kwargs = {\"pad_to_multiple_of\": 32, \"padding\": True, \"return_tensors\": \"tf\"}\n",
"generation_kwargs = {\"num_beams\": 4, \"max_new_tokens\": 32}\n",
"\n",
"# 3. Create your XLA generate function a̶n̶d̶ ̶m̶a̶k̶e̶ ̶P̶y̶T̶o̶r̶c̶h̶ ̶e̶a̶t̶ ̶d̶u̶s̶t̶\n",
"# This is the only change with respect to original generate workflow!\n",
"xla_generate = tf.function(model.generate, jit_compile=True)\n",
"\n",
"# 4. Generate! Remember -- the first call will be slow, but all subsequent calls will be fast if you've done things right.\n",
"input_prompts = [f\"The best thing about {country} is\" for country in [\"Spain\", \"Japan\", \"Angola\"]]\n",
"for input_prompt in input_prompts:\n",
" tokenized_inputs = tokenizer([input_prompt], **tokenization_kwargs)\n",
" start = time.time_ns()\n",
" generated_text = xla_generate(**tokenized_inputs, **generation_kwargs)\n",
" end = time.time_ns()\n",
" decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)\n",
" print(f\"Original prompt -- {input_prompt}\")\n",
" print(f\"Generated -- {decoded_text}\")\n",
" print(f\"Execution time -- {(end - start) / 1e6:.1f} ms\\n\")"
]
}
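,
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional extra: a minimal sketch of batched XLA generation with a decoder-only model.\n",
"# Because every prompt is left-padded to the same bucketed length, one compiled trace handles the whole batch in a single call.\n",
"# As above, GPT-2's EOS token is reused for padding; exact timings depend on your hardware.\n",
"import time\n",
"import tensorflow as tf\n",
"from transformers import AutoTokenizer, TFAutoModelForCausalLM\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"gpt2\", padding_side=\"left\", pad_token=\"<|endoftext|>\")\n",
"model = TFAutoModelForCausalLM.from_pretrained(\"gpt2\")\n",
"xla_generate = tf.function(model.generate, jit_compile=True)\n",
"\n",
"input_prompts = [f\"The best thing about {country} is\" for country in [\"Spain\", \"Japan\", \"Angola\"]]\n",
"tokenized_inputs = tokenizer(input_prompts, pad_to_multiple_of=32, padding=True, return_tensors=\"tf\")\n",
"\n",
"start = time.time_ns()\n",
"generated = xla_generate(**tokenized_inputs, num_beams=4, max_new_tokens=32)\n",
"end = time.time_ns()\n",
"\n",
"for prompt, output in zip(input_prompts, generated):\n",
"    print(f\"Original prompt -- {prompt}\")\n",
"    print(f\"Generated -- {tokenizer.decode(output, skip_special_tokens=True)}\\n\")\n",
"print(f\"Batched execution time (first run includes compilation) -- {(end - start) / 1e6:.1f} ms\")"
]
}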
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.13 ('hf': venv)",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.8.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "84f3c8774ca1c74eb574ae1655a273850a12d5dbb694801a64998ecbefff8fe7"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}