performance_optimization/torch_compile.py

# This example showcases how to leverage the Llama 3.1 8B Instruct model using
# torch.compile to accelerate inference.
#
# You need CUDA and torch >= 2.3 in order to run this example.

import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence warnings when compiling

device = "cuda"
ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"

model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(ckpt)

prompt = "Why dogs are so cute?"
inputs = tokenizer(prompt, return_tensors="pt").to(device)

# Specify the max length (including both the prompt and the response).
# When calling `generate` with `cache_implementation="static"` later, this is also used to create a `StaticCache`
# object with sequence length = `max_length`. The larger it is, the more generations you can re-use the cache for.
model.generation_config.max_length = 128

# Without `torch.compile`: each call takes ~5.0 seconds (on A100 80G + torch 2.3).
outputs = model.generate(**inputs, do_sample=False)
response = tokenizer.batch_decode(outputs)[0]
print(response)

# `torch.compile(model, ...)` is not recommended, as it would also compile callbacks and the full `generate` loop.
# We recommend compiling only the forward pass for now.
# "reduce-overhead" will use CUDA graphs.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
model.generation_config.cache_implementation = "static"

# With `torch.compile` (on A100 80G + torch 2.3):
# 1st call: ~90 seconds (the forward pass is compiled on first use)
outputs = model.generate(**inputs, do_sample=False)
response = tokenizer.batch_decode(outputs)[0]
# 2nd call: ~60 seconds
outputs = model.generate(**inputs, do_sample=False)
response = tokenizer.batch_decode(outputs)[0]
# 3rd call: ~1.5 seconds
outputs = model.generate(**inputs, do_sample=False)
response = tokenizer.batch_decode(outputs)[0]
print(response)
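
# Optional: a minimal timing sketch (not part of the original example) for reproducing the
# per-call numbers quoted in the comments above. The `timed_generate` helper below is
# hypothetical; it only wraps `model.generate` with GPU synchronization so the wall-clock
# measurement reflects the actual work. Exact timings depend on your GPU, driver, and torch version.
import time


def timed_generate(model, inputs):
    torch.cuda.synchronize()  # finish any pending GPU work before starting the clock
    start = time.perf_counter()
    out = model.generate(**inputs, do_sample=False)
    torch.cuda.synchronize()  # wait for generation to complete before stopping the clock
    print(f"generate took {time.perf_counter() - start:.2f} s")
    return out


# Example usage: calling it several times in a row shows the compile-then-speed-up pattern,
# e.g. `outputs = timed_generate(model, inputs)`.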