local_gemma/utils/benchmark.py

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from time import time

from tqdm import tqdm

# Prompt lengths and generation lengths to benchmark (all combinations are run).
PROMPT_LENGTH = [64, 2048]
MAX_NEW_TOKENS = [64, 2048]
NUM_RUNS = 5
WARMUP_RUNS = 3


def benchmark(model, assistant_model, tokenizer):
    """
    Benchmarks the throughput of the model. Does some warmup before measuring, to remove compilation time (if
    applicable).
    """
    max_prompt_length = max(PROMPT_LENGTH)
    model_inputs = tokenizer(
        ["foo bar " * 4000], return_tensors="pt", truncation=True, max_length=max_prompt_length
    )

    # sanity check: the dummy prompt must cover the largest benchmarked prompt length
    if model_inputs.input_ids.shape[1] != max_prompt_length:
        raise ValueError(
            f"Benchmark error: Model input length is {model_inputs.input_ids.shape[1]}, but expected to be "
            f"{max_prompt_length}."
        )

    # benchmark
    results = {}
    for prompt_length in PROMPT_LENGTH:
        for max_new_tokens in MAX_NEW_TOKENS:
            print(f"\nBenchmarking with prompt_length={prompt_length} and max_new_tokens={max_new_tokens}.")
            run_name = f"prompt_length={prompt_length}, max_new_tokens={max_new_tokens}"
            # min_new_tokens == max_new_tokens forces a fixed output length, so throughput is comparable across runs
            generate_kwargs = {
                "do_sample": False,
                "max_new_tokens": max_new_tokens,
                "min_new_tokens": max_new_tokens,
                "assistant_model": assistant_model,
            }
            input_ids = model_inputs.input_ids[:, :prompt_length].to(model.device)

            for _ in tqdm(range(WARMUP_RUNS), desc="Warming up"):
                model.generate(input_ids, **generate_kwargs)

            tokens_per_second = []
            for _ in tqdm(range(NUM_RUNS), desc="Benchmarking"):
                start = time()
                gen_out = model.generate(input_ids, **generate_kwargs)
                end = time()
                if gen_out.shape[1] != prompt_length + max_new_tokens:
                    raise ValueError(
                        f"Benchmark error: Generated output length is {gen_out.shape[1]}, but expected to be "
                        f"{prompt_length + max_new_tokens}."
                    )
                tokens_per_second.append(max_new_tokens / (end - start))

            results[run_name] = sum(tokens_per_second) / len(tokens_per_second)
            print(f"{run_name:40s}: {results[run_name]:.2f} tokens per second.\n")

    # print results
    print("\n\nResults:")
    for run_name, throughput in results.items():
        print(f"{run_name:40s}: {throughput:.2f} tokens per second.")
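

if __name__ == "__main__":
    # Usage sketch (not part of the original module): shows one way `benchmark()` could be
    # invoked with a Hugging Face causal LM. The checkpoint name, dtype, and device placement
    # below are illustrative assumptions, not requirements of the function.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "google/gemma-2-2b-it"  # hypothetical choice; swap in whichever checkpoint you want to time
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Pass a smaller draft model as `assistant_model` to benchmark assisted generation; None disables it.
    benchmark(model, assistant_model=None, tokenizer=tokenizer)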