performance_optimization/torch_compile_with_torchao.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load quantized model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig\n",
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\" # silence warnings when compiling\n",
"\n",
"device = \"cuda\"\n",
"ckpt = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n",
"\n",
"# Set the quantization config\n",
"# You can choose between int4_weight_only (4-bit), int8_weight_only (8-bit) and int8_dynamic_activation_int8_weight (8-bit)\n",
"# group_size is only for int4_weight_only and needs to be one of [32,64,128,256]\n",
"quantization_config = TorchAoConfig(quant_type=\"int4_weight_only\", group_size=128)\n",
"# Loading the quantized model takes 6218 MB\n",
"model = AutoModelForCausalLM.from_pretrained(ckpt,\n",
" torch_dtype=torch.bfloat16,\n",
" quantization_config=quantization_config,\n",
" device_map=device\n",
" )"
]
},
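{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, sanity-check the memory figure quoted above. The cell below is a minimal sketch using `model.get_memory_footprint()` (a `transformers` helper that sums parameter and buffer sizes) and `torch.cuda.max_memory_allocated()`; exact numbers will vary with your hardware and library versions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: report how much memory the quantized model occupies\n",
"print(f\"Model memory footprint: {model.get_memory_footprint() / 1024**2:.0f} MB\")\n",
"# Peak CUDA memory PyTorch has allocated on the current device so far\n",
"print(f\"Peak CUDA memory allocated: {torch.cuda.max_memory_allocated() / 1024**2:.0f} MB\")"
]
},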
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Set tokenizer and generation config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(ckpt)\n",
"\n",
"prompt = \"Why dogs are so cute?\"\n",
"inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
"\n",
"# Specify the max length (including both the prompt and the response)\n",
"# When calling `generate` with `cache_implementation=\"static\" later, this is also used to create a `StaticCache` object\n",
"# with sequence length = `max_length`. The longer the more you will re-use it\n",
"model.generation_config.max_length = 128\n",
"model.generation_config.cache_implementation = \"static\"\n"
]
},
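{
"cell_type": "markdown",
"metadata": {},
"source": [
"The latency numbers quoted in the next cells were measured by hand (A100 80G + torch 2.5). If you want to reproduce them, a small helper like the sketch below is enough; `timed_generate` is just an illustrative name, and the explicit `torch.cuda.synchronize()` calls are needed because CUDA kernels are launched asynchronously."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"# Sketch of a timing helper for reproducing the latency numbers below.\n",
"# Synchronize before and after so the wall-clock time covers the full generation.\n",
"def timed_generate(model, inputs):\n",
"    torch.cuda.synchronize()\n",
"    start = time.perf_counter()\n",
"    outputs = model.generate(**inputs, do_sample=False)\n",
"    torch.cuda.synchronize()\n",
"    return outputs, time.perf_counter() - start"
]
},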
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### test generate without compile"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# without `torch.compile`: each call takes ~ 5.0 seconds (on A100 80G + torch 2.5)\n",
"outputs = model.generate(**inputs, do_sample=False)\n",
"response = tokenizer.batch_decode(outputs)[0]\n",
"response"
]
},
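{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a usage example of the timing sketch above, the cell below re-times the eager (uncompiled) call; expect something close to the ~5 seconds quoted in the comment, depending on your GPU."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Example: measure the eager (uncompiled) latency with the helper above\n",
"outputs, eager_seconds = timed_generate(model, inputs)\n",
"print(f\"eager generate: {eager_seconds:.2f} s\")"
]
},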
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### test generate with compile"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# `torch.compile(model, ...)` is not recommended as you compile callbacks\n",
"# and full generate. We recommend compiling only the forward for now. \n",
"# \"reduce-overhead\" will use cudagraphs. \n",
"# Compile with torchao requires torch 2.5 or torch nighlty if the version is still not on pipy\n",
"model.forward = torch.compile(model.forward, mode=\"reduce-overhead\", fullgraph=True)\n",
"\n",
"# with `torch.compile` (on A100 80G + torch 2.5)\n",
"# 1st call: ~ 60 seconds\n",
"outputs = model.generate(**inputs, do_sample=False)\n",
"response = tokenizer.batch_decode(outputs)[0]\n",
"response"
]
},
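{
"cell_type": "markdown",
"metadata": {},
"source": [
"`batch_decode` returns the prompt together with the continuation. If you only want the newly generated text, you can slice off the prompt tokens first, as in the small sketch below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: decode only the newly generated tokens (drop the prompt part)\n",
"prompt_length = inputs[\"input_ids\"].shape[1]\n",
"new_tokens = outputs[0, prompt_length:]\n",
"print(tokenizer.decode(new_tokens, skip_special_tokens=True))"
]
},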
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 2nd call: ~ 1.5 seconds\n",
"outputs = model.generate(**inputs, do_sample=False)\n",
"response = tokenizer.batch_decode(outputs)[0]\n",
"response"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 3nd call: ~ 1 seconds\n",
"outputs = model.generate(**inputs, do_sample=False)\n",
"response = tokenizer.batch_decode(outputs)[0]\n",
"print(response)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}