def main()

in benchmarks/big_model_inference/big_model_inference.py

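Relies on the module-level imports in the same file (time, torch, and the transformers AutoConfig / AutoModelForCausalLM / AutoModelForSeq2SeqLM / AutoTokenizer classes), compute_module_sizes (presumably from accelerate.utils), and helpers defined elsewhere in the benchmark: parse_args, PROMPTS, and the start_measure / end_measure / log_measures measurement utilities.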

def main():
    transformers.utils.logging.set_verbosity_error()
    args = parse_args()

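    # Resolve the dtype: use the checkpoint config's default unless one was passed on the command line.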
    if args.torch_dtype is None:
        config = AutoConfig.from_pretrained(args.model_name)
        torch_dtype = getattr(config, "torch_dtype", torch.float32)
    else:
        torch_dtype = getattr(torch, args.torch_dtype)
    model_cls = AutoModelForCausalLM if args.is_causal else AutoModelForSeq2SeqLM
    kwargs = {
        "torch_dtype": torch_dtype,
        "revision": args.model_revision,
    }
    if args.disk_offload:
        kwargs["offload_folder"] = "tmp_offload"
        kwargs["offload_state_dict"] = True

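    # Load the model sharded across the available devices (device_map="auto"), measuring load time and memory.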
    start_measures = start_measure()
    model = model_cls.from_pretrained(args.model_name, device_map="auto", **kwargs)
    end_measures = end_measure(start_measures)
    log_measures(end_measures, "Model loading")

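    # Sum parameter sizes per device from the dispatch map to report the theoretical memory footprint.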
    module_sizes = compute_module_sizes(model)
    device_size = {v: 0 for v in model.hf_device_map.values()}
    for module, device in model.hf_device_map.items():
        device_size[device] += module_sizes[module]
    message = "\n".join([f"- {device}: {size // 2**20}MiB" for device, size in device_size.items()])
    print(f"\nTheoretical use:\n{message}")

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)

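    # Benchmark generation on each prompt, tracking wall-clock time and the number of generated tokens.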
    start_measures = start_measure()
    generation_times = []
    gen_tokens = []
    texts_outs = []
    for prompt in PROMPTS:
        inputs = tokenizer(prompt, return_tensors="pt").to(0)
        tokens = inputs["input_ids"][0].tolist()
        before_generate = time.time()
        outputs = model.generate(inputs["input_ids"])
        after_generate = time.time()
        outputs = outputs[0].tolist()
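        # Causal LMs echo the prompt in their output, so only subtract the prompt length when the output starts with the prompt tokens.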
        num_gen_tokens = len(outputs) if outputs[: len(tokens)] != tokens else len(outputs) - len(tokens)
        generation_time = after_generate - before_generate

        text_out = tokenizer.decode(outputs, skip_special_tokens=True)
        texts_outs.append(text_out)
        generation_times.append(generation_time)
        gen_tokens.append(num_gen_tokens)
        print(f"Prompt: {prompt}\nGeneration {text_out}\nIn {generation_time:.2f}s for {num_gen_tokens} tokens\n")

    end_measures = end_measure(start_measures)
    log_measures(end_measures, "Model generation")

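    # Per-token generation times; the first prompt typically includes warmup, so an average excluding it is reported as well.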
    generation_times_per_token = [gen / tok for gen, tok in zip(generation_times, gen_tokens)]
    avg_gen = sum(generation_times_per_token) / len(generation_times_per_token)
    print(f"Average time of generation per token: {avg_gen:.2f}s")
    print(f"First generation (avg time per token): {generation_times_per_token[0]:.2f}s")
    avg_gen = sum(generation_times_per_token[1:]) / (len(generation_times_per_token) - 1)
    print(f"Average time of generation per token (excluding the first): {avg_gen:.2f}s")