benchmarks/torch.compile/regional_compilation.py

# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Benchmark comparing full ``torch.compile`` against regional compilation
(``accelerate.utils.compile_regions``), measuring both compile time and
inference time on a set of non-gated Llama models."""

import torch
from torch.utils.benchmark import Compare, Timer
from transformers import AutoConfig, AutoModelForCausalLM

from accelerate.test_utils.testing import get_backend
from accelerate.utils import compile_regions


torch.set_float32_matmul_precision("high")

COMPILE_ITERS = 2
INFERENCE_ITERS = 100

BASELINE = "Baseline"
COMPILE_TIME = "Compile time"
INFERENCE_TIME = "Inference time"
FULL_COMPILATION = "Full compilation"
REGIONAL_COMPILATION = "Regional compilation"

INFERENCE_STMT = "model(input_ids, use_cache=False)"
# Resetting Dynamo and clearing the Inductor caches before the forward pass
# ensures the compile timing measures a cold compilation on every iteration.
COMPILE_STMT = f"torch._dynamo.reset(); torch._inductor.utils.clear_inductor_caches(); {INFERENCE_STMT}"

torch_device_type, _, _ = get_backend()

results = []
for model_id in [
    # non-gated llama models
    "NousResearch/Llama-3.2-1B",
    "NousResearch/Hermes-3-Llama-3.2-3B",
    "NousResearch/Hermes-3-Llama-3.1-8B",
    "NousResearch/Nous-Hermes-Llama2-13b",
]:
    with torch.device(torch_device_type):
        config = AutoConfig.from_pretrained(model_id)
        # Build the model from its config (random weights, no checkpoint
        # download) directly on the accelerator device, in float16.
        model = AutoModelForCausalLM.from_config(config).to(dtype=torch.float16).eval()

    full_compilation_model = torch.compile(model)
    regional_compilation_model = compile_regions(model)

    # Each variant is timed for compilation and/or inference; the loop
    # variable `model` rebinds the name referenced by the timed statements.
    for model, sub_label, description, stmt, iters in [
        (model, BASELINE, INFERENCE_TIME, INFERENCE_STMT, INFERENCE_ITERS),
        (full_compilation_model, FULL_COMPILATION, COMPILE_TIME, COMPILE_STMT, COMPILE_ITERS),
        (full_compilation_model, FULL_COMPILATION, INFERENCE_TIME, INFERENCE_STMT, INFERENCE_ITERS),
        (regional_compilation_model, REGIONAL_COMPILATION, COMPILE_TIME, COMPILE_STMT, COMPILE_ITERS),
        (regional_compilation_model, REGIONAL_COMPILATION, INFERENCE_TIME, INFERENCE_STMT, INFERENCE_ITERS),
    ]:
        for batch_size, sequence_length in [(1, 128), (4, 128)]:
            input_ids = torch.randint(
                0, 1000, size=(batch_size, sequence_length), dtype=torch.int64, device=torch_device_type
            )
            results.append(
                Timer(
                    label=model_id,
                    sub_label=sub_label,
                    description=f"{description} ({batch_size}x{sequence_length})",
                    globals={"model": model, "input_ids": input_ids},
                    stmt=stmt,
                ).timeit(number=iters)
            )

compare = Compare(results)
compare.colorize()
compare.print()
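

# For reference: a minimal sketch of the idea behind regional compilation.
# This is an illustration under assumptions, NOT accelerate's actual
# `compile_regions` implementation, and the helper name
# `_compile_repeated_blocks` is hypothetical. Instead of wrapping the whole
# model in `torch.compile`, each repeated block (e.g. every decoder layer)
# is compiled on its own, so compilation work can be shared across
# structurally identical layers and cold-start compile time stays low.
def _compile_repeated_blocks(model, block_cls):
    for name, child in model.named_children():
        if isinstance(child, block_cls):
            # Compile only this block; the rest of the model stays eager.
            setattr(model, name, torch.compile(child))
        else:
            # Recurse into containers (e.g. ModuleList) to find the blocks.
            _compile_repeated_blocks(child, block_cls)
    return model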