def benchmark()

in local_gemma/utils/benchmark.py


from time import time

from tqdm import tqdm

# PROMPT_LENGTH, MAX_NEW_TOKENS, WARMUP_RUNS, and NUM_RUNS are assumed to be
# module-level constants in benchmark.py: the first two are iterables of
# prompt/generation lengths to sweep, the last two are run counts.


def benchmark(model, assistant_model, tokenizer):
    """
    Benchmarks the throughput of the model. Runs a few warmup generations before measuring, to exclude
    compilation time (if applicable).
    """
    max_prompt_length = max(PROMPT_LENGTH)
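    # Build a dummy prompt long enough that truncation leaves exactly
    # max_prompt_length tokens (assumes "foo bar " * 4000 tokenizes to at
    # least that many tokens).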
    model_inputs = tokenizer(
        ["foo bar " * 4000], return_tensors="pt", truncation=True, max_length=max_prompt_length
    )

    # sanity check
    if model_inputs.input_ids.shape[1] != max_prompt_length:
        raise ValueError(
            f"Benchmark error: Model input length is {model_inputs.input_ids.shape[1]}, but expected to be "
            f"{max_prompt_length}."
        )

    # benchmark
    results = {}
    for prompt_length in PROMPT_LENGTH:
        for max_new_tokens in MAX_NEW_TOKENS:
            print(f"\nBenchmarking with prompt_length={prompt_length} and max_new_tokens={max_new_tokens}.")
            run_name = f"prompt_length={prompt_length}, max_new_tokens={max_new_tokens}"
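            # min_new_tokens == max_new_tokens forces a fixed-length generation,
            # so every timed run produces exactly max_new_tokens tokens.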
            generate_kwargs = {
                "do_sample": False,
                "max_new_tokens": max_new_tokens,
                "min_new_tokens": max_new_tokens,
                "assistant_model": assistant_model,
            }

            input_ids = model_inputs.input_ids[:, :prompt_length].to(model.device)
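            # Warmup generations absorb one-off costs (e.g. compilation) so
            # they don't skew the timed runs.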
            for _ in tqdm(range(WARMUP_RUNS), desc="Warming up"):
                model.generate(input_ids, **generate_kwargs)

            tokens_per_second = []
            for _ in tqdm(range(NUM_RUNS), desc="Benchmarking"):
                start = time()
                gen_out = model.generate(input_ids, **generate_kwargs)
                end = time()
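                # generate() returns the prompt plus the new tokens, so the
                # output must be exactly prompt_length + max_new_tokens long.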
                if gen_out.shape[1] != prompt_length + max_new_tokens:
                    raise ValueError(
                        f"Benchmark error: Generated output length is {gen_out.shape[1]}, but expected to be "
                        f"{prompt_length + max_new_tokens}."
                    )
                tokens_per_second.append(max_new_tokens / (end - start))
            results[run_name] = sum(tokens_per_second) / len(tokens_per_second)
            print(f"{run_name:40s}: {results[run_name]:.2f} tokens per second.\n")

    # print results
    print("\n\nResults:")
    for run_name, throughput in results.items():
        print(f"{run_name:40s}: {throughput:2f} tokens per second.")