assisted_decoding/assisted_decoding_70B_3B.ipynb (353 lines of code) (raw):

{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "hoB6HZTw4fvA" }, "source": [ "# Speed Up Llama 3.1 70B with Speculative Decoding\n", "\n", "In this guide, we'll show you how to implement speculative decoding with the Llama 3.1 70B model (base) and the Llama 3.2 3B model (assistant). Transformers has a `generate` API that accepts an `assistant_model` argument to enable speculative decoding. By the end, you'll see how this technique can speed up text generation by up to 2x, making your workflows faster and more efficient." ] }, { "cell_type": "markdown", "metadata": { "id": "8h73s8MaxqrK" }, "source": [ "## Imports and Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NaWzzUs5yfwo" }, "outputs": [], "source": [ "!pip install -Uq transformers accelerate" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6Kp5KI7qxsXi" }, "outputs": [], "source": [ "from transformers import AutoModelForCausalLM, AutoTokenizer\n", "import time\n", "import torch\n", "from tqdm import tqdm" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "q5lpixXsrJQb" }, "outputs": [], "source": [ "# suppress warnings in the notebook\n", "import logging\n", "import warnings\n", "logging.getLogger('transformers').setLevel(logging.ERROR)\n", "warnings.filterwarnings('ignore')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IUEYY5j71ph0" }, "outputs": [], "source": [ "# base models\n", "checkpoint = \"meta-llama/Meta-Llama-3.1-70B\" # <-- Larger Model\n", "assistant_checkpoint = \"meta-llama/Llama-3.2-3B\" # <-- Smaller Model" ] }, { "cell_type": "markdown", "metadata": { "id": "FyUZ7dpBmVjL" }, "source": [ "## Prepare Models" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LihenVlfgsuR" }, "outputs": [], "source": [ "model = AutoModelForCausalLM.from_pretrained(\n", " checkpoint,\n", " device_map=\"auto\",\n", " torch_dtype=torch.bfloat16,\n", ")\n", "assistant_model = 
AutoModelForCausalLM.from_pretrained(\n", " assistant_checkpoint,\n", " device_map=\"auto\",\n", " torch_dtype=torch.bfloat16,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "yE5QDkOOmZ64" }, "source": [ "### Prepare Inputs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "na-Qefklg0nm" }, "outputs": [], "source": [ "prompt = \"Alice and Bob\"\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n", "model_inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)" ] }, { "cell_type": "markdown", "metadata": { "id": "QTXUxJlWmieJ" }, "source": [ "### Benchmark the text generation speed\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nZQr3rSJg-eO" }, "outputs": [], "source": [ "max_new_tokens = 256\n", "num_iterations = 10" ] }, { "cell_type": "markdown", "metadata": { "id": "9t9jh24rrdYb" }, "source": [ "### Greedy Decoding" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mP8co683-T2M" }, "outputs": [], "source": [ "# warmup: run a few untimed generations so one-time setup cost is excluded from the benchmark\n", "# eos_token_id=-1 matches no real token, so every run generates exactly max_new_tokens tokens\n", "for _ in range(2):\n", " model.generate(\n", " **model_inputs,\n", " do_sample=False,\n", " assistant_model=assistant_model,\n", " max_new_tokens=max_new_tokens,\n", " eos_token_id=-1,\n", " )" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "O482aPT-47xH" }, "outputs": [], "source": [ "# without assistance\n", "print(\"🐘 Big Model Generation\")\n", "duration = 0\n", "for _ in tqdm(range(num_iterations)):\n", " start = time.time()\n", " outputs = model.generate(\n", " **model_inputs,\n", " do_sample=False,\n", " max_new_tokens=max_new_tokens,\n", " eos_token_id=-1,\n", " )\n", " end = time.time()\n", " duration = duration + (end-start)\n", "\n", "print(f\"\\nThroughput: {(num_iterations * max_new_tokens) / (duration):.4f} (tokens/sec)\")\n", "\n", "\n", "# with assistance\n", "print(\"\\n\\n🤝 Big Model Generation with Assistance\")\n", "duration = 0\n", "for _ in tqdm(range(num_iterations)):\n", " start = time.time()\n", " 
outputs = model.generate(\n", " **model_inputs,\n", " do_sample=False,\n", " max_new_tokens=max_new_tokens,\n", " eos_token_id=-1,\n", " assistant_model=assistant_model,\n", " )\n", " end = time.time()\n", " duration = duration + (end-start)\n", "\n", "print(f\"\\nThroughput: {(num_iterations * max_new_tokens) / (duration):.4f} (tokens/sec)\")" ] }, { "cell_type": "markdown", "metadata": { "id": "fu8FmSpUrgU7" }, "source": [ "### Multinomial Decoding" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wd6gbarl8GTQ" }, "outputs": [], "source": [ "# warmup\n", "for _ in range(2):\n", " model.generate(\n", " **model_inputs,\n", " do_sample=True,\n", " temperature=0.2,\n", " assistant_model=assistant_model,\n", " max_new_tokens=max_new_tokens,\n", " eos_token_id=-1,\n", " )" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "asU-_sjOgeJP" }, "outputs": [], "source": [ "# without assistance\n", "print(\"🐘 Big Model Generation (multinomial)\")\n", "duration = 0\n", "for _ in tqdm(range(num_iterations)):\n", " start = time.time()\n", " outputs = model.generate(\n", " **model_inputs,\n", " do_sample=True,\n", " temperature=0.2,\n", " max_new_tokens=max_new_tokens,\n", " eos_token_id=-1,\n", " )\n", " end = time.time()\n", " duration = duration + (end-start)\n", "\n", "print(f\"\\nThroughput: {(num_iterations * max_new_tokens) / (duration):.4f} (tokens/sec)\")\n", "\n", "\n", "# with assistance\n", "print(\"\\n\\n🤝 Big Model Generation with Assistance (multinomial)\")\n", "duration = 0\n", "for _ in tqdm(range(num_iterations)):\n", " start = time.time()\n", " outputs = model.generate(\n", " **model_inputs,\n", " do_sample=True,\n", " temperature=0.2,\n", " max_new_tokens=max_new_tokens,\n", " eos_token_id=-1,\n", " assistant_model=assistant_model,\n", " )\n", " end = time.time()\n", " duration = duration + (end-start)\n", "\n", "print(f\"\\nThroughput: {(num_iterations * max_new_tokens) / (duration):.4f} (tokens/sec)\")" ] }, { "cell_type": "markdown", "metadata": { 
"id": "N3qi22BqzzQ5" }, "source": [ "## Conclusion\n", "\n", "\n", "| Decoding strategy | Unassisted throughput (tokens/sec) | Assisted throughput (tokens/sec) |\n", "| :-- | --: | --: |\n", "| greedy | 4.9464 | **9.7564** |\n", "| multinomial | 4.9309 | **6.2531** |\n", "\n", "\n", "Assisted generation increases the base model's throughput in both settings: roughly 2x for greedy decoding and roughly 1.3x for multinomial sampling. 🎉\n", "\n", "This speedup comes at the cost of higher memory usage, because the assistant model must be loaded alongside the base model, so it's important to balance both metrics." ] }, { "cell_type": "markdown", "metadata": { "id": "jeYf-CTXrv0k" }, "source": [ "## Next Steps\n", "\n", "To learn more about speculative decoding, we suggest reading:\n", "\n", "- [Assisted Generation](https://huggingface.co/blog/assisted-generation): Learn more about assisted generation.\n", "- [Speculative Decoding Docs](https://huggingface.co/docs/transformers/main/en/generation_strategies#speculative-decoding): See how Transformers decodes with an assistant model.\n", "\n", "## Acknowledgements\n", "\n", "1. [Vaibhav Srivastav](https://huggingface.co/reach-vb) for the thorough review and suggestions to make the tutorial better.\n", "2. [Joao Gante](https://huggingface.co/joaogante) for clarifying my doubts on speculative decoding." ] } ], "metadata": { "accelerator": "GPU", "colab": { "authorship_tag": "ABX9TyNYYY/vwph6rjyuUkmXkFdk", "gpuType": "L4", "machine_shape": "hm", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }