{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "hoB6HZTw4fvA"
},
"source": [
"# Speed Up Llama 3 70B with Speculative Decoding\n",
"\n",
"In this guide, we'll show you how to implement speculative decoding with the Llama 3.1 70B model (base) and the Llama 3.2 3B model (assistant). Transformers' `generate` API accepts an `assistant_model` argument that enables speculative decoding. By the end, you'll see how this technique can significantly speed up text generation (up to 2x), making your workflows faster and more efficient."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8h73s8MaxqrK"
},
"source": [
"## Imports and Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NaWzzUs5yfwo"
},
"outputs": [],
"source": [
"!pip install -Uq transformers accelerate"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6Kp5KI7qxsXi"
},
"outputs": [],
"source": [
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"import time\n",
"import torch\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "q5lpixXsrJQb"
},
"outputs": [],
"source": [
"# suppress warnings in the notebook\n",
"import logging\n",
"import warnings\n",
"logging.getLogger('transformers').setLevel(logging.ERROR)\n",
"warnings.filterwarnings('ignore')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IUEYY5j71ph0"
},
"outputs": [],
"source": [
"# model checkpoints: large base (target) model and small assistant (draft) model\n",
"checkpoint = \"meta-llama/Meta-Llama-3.1-70B\" # <-- Larger Model\n",
"assistant_checkpoint = \"meta-llama/Llama-3.2-3B\" # <-- Smaller Model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FyUZ7dpBmVjL"
},
"source": [
"## Prepare Models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LihenVlfgsuR"
},
"outputs": [],
"source": [
"model = AutoModelForCausalLM.from_pretrained(\n",
" checkpoint,\n",
" device_map=\"auto\",\n",
" torch_dtype=torch.bfloat16,\n",
")\n",
"assistant_model = AutoModelForCausalLM.from_pretrained(\n",
" assistant_checkpoint,\n",
" device_map=\"auto\",\n",
" torch_dtype=torch.bfloat16,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yE5QDkOOmZ64"
},
"source": [
"### Prepare Inputs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "na-Qefklg0nm"
},
"outputs": [],
"source": [
"prompt = \"Alice and Bob\"\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
"model_inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QTXUxJlWmieJ"
},
"source": [
"### Benchmark the text generation speed\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nZQr3rSJg-eO"
},
"outputs": [],
"source": [
"max_new_tokens = 256\n",
"num_iterations = 10"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9t9jh24rrdYb"
},
"source": [
"### Greedy Decoding"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mP8co683-T2M"
},
"outputs": [],
"source": [
"# warmup: run generation twice so one-time costs (kernel compilation, cache setup) don't skew the timings\n",
"for _ in range(2):\n",
" model.generate(\n",
" **model_inputs,\n",
" do_sample=False,\n",
" assistant_model=assistant_model,\n",
" max_new_tokens=max_new_tokens,\n",
" eos_token_id=-1,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "O482aPT-47xH"
},
"outputs": [],
"source": [
"# without assistance\n",
"print(\"🚀 Big Model Generation\")\n",
"duration = 0\n",
"for _ in tqdm(range(num_iterations)):\n",
" start = time.time()\n",
" outputs = model.generate(\n",
" **model_inputs,\n",
" do_sample=False,\n",
" max_new_tokens=max_new_tokens,\n",
" eos_token_id=-1,\n",
" )\n",
" end = time.time()\n",
" duration = duration + (end-start)\n",
"\n",
"print(f\"\\nThroughput: {(num_iterations * max_new_tokens) / (duration):.4f} (tokens/sec)\")\n",
"\n",
"\n",
"# with assistance\n",
"print(\"\\n\\n🤝 Big Model Generation with Assistance\")\n",
"duration = 0\n",
"for _ in tqdm(range(num_iterations)):\n",
" start = time.time()\n",
" outputs = model.generate(\n",
" **model_inputs,\n",
" do_sample=False,\n",
" max_new_tokens=max_new_tokens,\n",
" eos_token_id=-1,\n",
" assistant_model=assistant_model,\n",
" )\n",
" end = time.time()\n",
" duration = duration + (end-start)\n",
"\n",
"print(f\"\\nThroughput: {(num_iterations * max_new_tokens) / (duration):.4f} (tokens/sec)\")"
]
},
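{
"cell_type": "markdown",
"metadata": {},
"source": [
"Speculative decoding is designed to be lossless: the big model verifies every token drafted by the assistant, so greedy decoding with and without assistance should produce the same sequence (tiny bfloat16 numerical differences can occasionally cause a divergence). The optional cell below is a minimal sketch of that sanity check on this prompt."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# optional sanity check (minimal sketch): compare greedy outputs with and without the assistant\n",
"unassisted = model.generate(\n",
" **model_inputs,\n",
" do_sample=False,\n",
" max_new_tokens=max_new_tokens,\n",
" eos_token_id=-1,\n",
")\n",
"assisted = model.generate(\n",
" **model_inputs,\n",
" do_sample=False,\n",
" max_new_tokens=max_new_tokens,\n",
" eos_token_id=-1,\n",
" assistant_model=assistant_model,\n",
")\n",
"print(\"outputs match:\", torch.equal(unassisted, assisted))\n",
"print(tokenizer.decode(assisted[0], skip_special_tokens=True))"
]
},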
{
"cell_type": "markdown",
"metadata": {
"id": "fu8FmSpUrgU7"
},
"source": [
"### Multinomial Decoding"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wd6gbarl8GTQ"
},
"outputs": [],
"source": [
"# warmup\n",
"for _ in range(2):\n",
" model.generate(\n",
" **model_inputs,\n",
" do_sample=True,\n",
" temperature=0.2,\n",
" assistant_model=assistant_model,\n",
" max_new_tokens=max_new_tokens,\n",
" eos_token_id=-1,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "asU-_sjOgeJP"
},
"outputs": [],
"source": [
"# without assistance\n",
"print(\"🚀 Big Model Generation (multinomial)\")\n",
"duration = 0\n",
"for _ in tqdm(range(num_iterations)):\n",
" start = time.time()\n",
" outputs = model.generate(\n",
" **model_inputs,\n",
" do_sample=True,\n",
" temperature=0.2,\n",
" max_new_tokens=max_new_tokens,\n",
" eos_token_id=-1,\n",
" )\n",
" end = time.time()\n",
" duration = duration + (end-start)\n",
"\n",
"print(f\"\\nThroughput: {(num_iterations * max_new_tokens) / (duration):.4f} (tokens/sec)\")\n",
"\n",
"\n",
"# with assistance\n",
"print(\"\\n\\n🤝 Big Model Generation with Assistance (multinomial)\")\n",
"duration = 0\n",
"for _ in tqdm(range(num_iterations)):\n",
" start = time.time()\n",
" outputs = model.generate(\n",
" **model_inputs,\n",
" do_sample=True,\n",
" temperature=0.2,\n",
" max_new_tokens=max_new_tokens,\n",
" eos_token_id=-1,\n",
" assistant_model=assistant_model,\n",
" )\n",
" end = time.time()\n",
" duration = duration + (end-start)\n",
"\n",
"print(f\"\\nThroughput: {(num_iterations * max_new_tokens) / (duration):.4f} (tokens/sec)\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N3qi22BqzzQ5"
},
"source": [
"## Conclusion\n",
"\n",
"\n",
"| Decoding | Throughput without assistance (tokens/sec) | Throughput with assistance (tokens/sec) |\n",
"| :-- | --: | --: |\n",
"| greedy | 4.9464 | **9.7564** |\n",
"| multinomial | 4.9309 | **6.2531** |\n",
"\n",
"\n",
"The throughput increases with assisted generation! 🚀\n",
"\n",
"The speedup does come at a cost: the assistant model must be held in GPU memory alongside the base model, so it's important to balance speed against memory usage."
]
},
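{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want to quantify the memory overhead on your own setup, one option (a minimal sketch, assuming a CUDA machine; note that `torch.cuda.max_memory_allocated` only reports the current device, so the numbers are partial when the model is sharded across several GPUs) is to reset and read PyTorch's peak-memory counter around a single `generate` call:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# minimal sketch: peak GPU memory (current CUDA device only) for one generate call\n",
"def peak_memory_gb(**generate_kwargs):\n",
" torch.cuda.empty_cache()\n",
" torch.cuda.reset_peak_memory_stats()\n",
" model.generate(**model_inputs, max_new_tokens=max_new_tokens, eos_token_id=-1, **generate_kwargs)\n",
" return torch.cuda.max_memory_allocated() / 1024**3\n",
"\n",
"print(f\"simple: {peak_memory_gb(do_sample=False):.2f} GB\")\n",
"print(f\"assisted: {peak_memory_gb(do_sample=False, assistant_model=assistant_model):.2f} GB\")"
]
},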
{
"cell_type": "markdown",
"metadata": {
"id": "jeYf-CTXrv0k"
},
"source": [
"## Next Steps\n",
"\n",
"To learn more about speculative decoding, we suggest reading:\n",
"\n",
"- [Assisted Generation](https://huggingface.co/blog/assisted-generation): The blog post introducing assisted generation in Transformers.\n",
"- [Speculative Decoding Docs](https://huggingface.co/docs/transformers/main/en/generation_strategies#speculative-decoding): How Transformers implements decoding with an assistant model.\n",
"\n",
"## Acknowledgements\n",
"\n",
"1. [Vaibhav Srivastav](https://huggingface.co/reach-vb) for the thorough review and suggestions to make the tutorial better.\n",
"2. [Joao Gante](https://huggingface.co/joaogante) for clarifying my doubts on speculative decoding."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"authorship_tag": "ABX9TyNYYY/vwph6rjyuUkmXkFdk",
"gpuType": "L4",
"machine_shape": "hm",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}