in local_gemma/utils/benchmark.py [0:0]
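from time import time

from tqdm import tqdm

# NOTE (assumed context): the sweep constants used below (PROMPT_LENGTH, MAX_NEW_TOKENS, WARMUP_RUNS,
# NUM_RUNS) are expected to be defined at module level elsewhere in this file. The commented values are
# illustrative placeholders only, not the project's actual settings:
# PROMPT_LENGTH = [64, 512]
# MAX_NEW_TOKENS = [64, 512]
# WARMUP_RUNS = 2
# NUM_RUNS = 5
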
def benchmark(model, assistant_model, tokenizer):
    """
    Benchmarks the throughput of the model. Runs a few warmup generations before measuring, to exclude
    compilation time (if applicable).
    """
    max_prompt_length = max(PROMPT_LENGTH)
    # build a dummy prompt and truncate it to the longest prompt length in the sweep; shorter prompts are
    # sliced from it below
    model_inputs = tokenizer(
        ["foo bar " * 4000], return_tensors="pt", truncation=True, max_length=max_prompt_length
    )
    # sanity check: the dummy prompt must tokenize to exactly the longest prompt length
    if model_inputs.input_ids.shape[1] != max_prompt_length:
        raise ValueError(
            f"Benchmark error: Model input length is {model_inputs.input_ids.shape[1]}, but expected to be "
            f"{max_prompt_length}."
        )
    # benchmark all combinations of prompt length and number of generated tokens
    results = {}
    for prompt_length in PROMPT_LENGTH:
        for max_new_tokens in MAX_NEW_TOKENS:
            print(f"\nBenchmarking with prompt_length={prompt_length} and max_new_tokens={max_new_tokens}.")
            run_name = f"prompt_length={prompt_length}, max_new_tokens={max_new_tokens}"
            generate_kwargs = {
                "do_sample": False,
                # setting `min_new_tokens` equal to `max_new_tokens` forces exactly `max_new_tokens`
                # generated tokens, so throughput is comparable across runs
                "max_new_tokens": max_new_tokens,
                "min_new_tokens": max_new_tokens,
                "assistant_model": assistant_model,
            }
            input_ids = model_inputs.input_ids[:, :prompt_length].to(model.device)
            # warmup runs are not timed: they absorb one-off costs such as compilation (if applicable)
            for _ in tqdm(range(WARMUP_RUNS), desc="Warming up"):
                model.generate(input_ids, **generate_kwargs)
            # timed runs: measure decoding throughput in generated tokens per second
            tokens_per_second = []
            for _ in tqdm(range(NUM_RUNS), desc="Benchmarking"):
                start = time()
                gen_out = model.generate(input_ids, **generate_kwargs)
                end = time()
                if gen_out.shape[1] != prompt_length + max_new_tokens:
                    raise ValueError(
                        f"Benchmark error: Generated output length is {gen_out.shape[1]}, but expected to be "
                        f"{prompt_length + max_new_tokens}."
                    )
                tokens_per_second.append(max_new_tokens / (end - start))
            results[run_name] = sum(tokens_per_second) / len(tokens_per_second)
            print(f"{run_name:40s}: {results[run_name]:.2f} tokens per second.\n")
    # print results
    print("\n\nResults:")
    for run_name, throughput in results.items():
        print(f"{run_name:40s}: {throughput:.2f} tokens per second.")