def main()

in local_gemma/cli.py


def main():
    args = parser.parse_args()

    # If data is being piped in via stdin (i.e. stdin is not a TTY), append it to the prompt.
    stdin_received = not sys.stdin.isatty()
    if stdin_received:
        input_data = sys.stdin.read()
        args.prompt = args.prompt + ["\n"] + [input_data]

    device = infer_device(args.device)
    dtype = infer_dtype(device, args.dtype)
    generation_kwargs = get_generation_kwargs(args.mode)
    base_prompt = get_prompt(args.mode)
    has_starting_prompt = len(args.prompt) > 0
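    # `MODEL_NAMES` maps short aliases like "9b" or "27b" to their Hugging Face Hub repo ids;
    # full repo ids are passed through unchanged.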
    model_name = MODEL_NAMES.get(args.model) or args.model
    if args.token is None:
        if get_token() is None:
            print("Using the gated Gemma model requires you to:")
            print("1. Create an account on the Hugging Face Hub: https://huggingface.co/join")
            print("2. Accept the Gemma-2 model terms of use: https://huggingface.co/google/gemma-2-9b")
            print("3. Create an access token and paste it below: https://huggingface.co/settings/tokens")
            login()

    if args.preset == "auto":
        args.preset, spare_memory = infer_memory_requirements(
            model_name, device, trust_remote_code=False, token=args.token
        )
    else:
        # `spare_memory` is only measured for the "auto" preset; default to 0 so that the assisted
        # generation branch below falls back to the more conservative "memory" preset.
        spare_memory = 0

    # Triggers assisted generation on CUDA or MPS devices, assuming the default 9b or 27b models are used. Assisted
    # generation is not beneficial on most CPU settings. Can't be used with the speed preset (more precisely, with
    # `torch.compile`).
    if (
        args.model in ('9b', '27b')
        and ("cuda" in device or device.isdigit() or "mps" in device)
        and args.preset != "speed"
    ):
        assistant_model_name = MODEL_NAMES["2b"]
        if spare_memory / 1e9 > 5:
            assistant_preset = "exact"
        else:
            assistant_preset = "memory"
    else:
        assistant_model_name = None

    if not args.silent:
        print("\nLoading model with the following characteristics:")
        print("- Model name:", model_name)
        print(f"- Assistant model name: {assistant_model_name if assistant_model_name is None else assistant_model_name + f' (loaded with `{assistant_preset}` preset)'}")
        print("- Device:", device)
        print("- Default data type:", str(dtype))
        print("- Optimization preset:", args.preset)
        print("- Generation arguments:", str(generation_kwargs))
        print("- Base prompt:", repr(base_prompt) if len(base_prompt) > 0 else "None")
        print("")
    else:
        logging.disable_progress_bar()

    tokenizer = AutoTokenizer.from_pretrained(model_name, token=args.token)
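    # Instruction-tuned checkpoints ship a chat template in their tokenizer config; base checkpoints do not.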
    is_instruction_tuned = tokenizer.chat_template is not None

    if args.preset == "speed" and device == "cuda" and (has_starting_prompt or not is_instruction_tuned):
        # for single-turn responses, disable torch compile and enable fa2
        # this way, we skip the lengthy compilation step and return the generation to the user as quickly as possible
        torch_compile = False
        attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else None
    else:
        # leave to the preset to decide these settings
        torch_compile = attn_implementation = None

    model = LocalGemma2ForCausalLM.from_pretrained(
        model_name,
        preset=args.preset,
        torch_compile=torch_compile,
        token=args.token,
        torch_dtype=dtype,
        device=device,
        attn_implementation=attn_implementation,
    )
    # TODO(joao): this if shouldn't be needed, fix in transformers
    model._supports_cache_class = True

    if assistant_model_name is not None:
        assistant_model = LocalGemma2ForCausalLM.from_pretrained(
            assistant_model_name, preset=assistant_preset, token=args.token, torch_dtype=dtype, device=device)
    else:
        assistant_model = None

    if args.benchmark:
        benchmark(model=model, assistant_model=assistant_model, tokenizer=tokenizer)
    else:
        if args.seed is not None:
            set_seed(args.seed)

        if device == "mps" and args.max_new_tokens is None:
            print(
                "Setting max new tokens to 1024 for faster MPS generation. To override this limit, pass "
                "`--max_new_tokens` explicitly (e.g. `--max_new_tokens=2048`)."
            )
            args.max_new_tokens = 1024

        # Note: as of transformers 4.44, assisted generation does NOT work with any cache except dynamic cache
        if args.max_new_tokens is None and assistant_model is None:
            cache = HybridCache(
                model.config,
                max_batch_size=1,
                max_cache_len=model.config.max_position_embeddings,
                device=model.device,
                dtype=model.dtype,
            )
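            # A `HybridCache` instance is passed explicitly via `past_key_values` below, so disable the
            # config-driven cache creation that `generate` would otherwise perform.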
            model.generation_config.cache_implementation = None
        else:
            # When `max_new_tokens` is set (or an assistant model is used), fall back to the default
            # dynamic cache, which grows on each generation step instead of pre-allocating the full context.
            cache = None

        if hasattr(model.forward, "_torchdynamo_orig_callable"):
            print(
                "Compiling the model forward pass. This may take a few minutes, particularly the first time it is "
                "run..."
            )
            if not is_torch_version(">=", "2.4"):
                print(
                    "Install torch>=2.4.0 to cache the FX graphs across runs: https://pytorch.org/get-started/locally/"
                )
            chat_history = [{"role": "user", "content": "The theory of relativity states"}, ]
            # Two warm-up runs: the first warms up the model (Triton autotuning, etc.), the second
            # records the graph and replays it. The third run onwards is the fast path.
            for _ in range(2):
                dummy_inputs = tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)
                model_inputs = tokenizer(dummy_inputs, return_tensors="pt").to(model.device)
                # prefill + generation
                model_tokens = model.generate(**model_inputs, past_key_values=cache, max_new_tokens=16)
                model_output_text = tokenizer.decode(model_tokens[0, model_inputs.input_ids.shape[1]:], skip_special_tokens=True)
                chat_history += [
                    {"role": "assistant", "content": model_output_text},
                    {"role": "user", "content": "Please repeat the above!"},
                ]
                if cache is not None:
                    cache.reset()

        if not args.silent and not has_starting_prompt:
            print_help(is_instruction_tuned=is_instruction_tuned)

        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
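        # Start the interactive session with a fresh chat history (any warm-up turns above are discarded).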
        chat_history = []
        while True:
            # Get input to the model
            if has_starting_prompt:
                user_input = " ".join(args.prompt)
            else:
                user_input = input(">>> ")

            # Handle special commands
            if user_input in EXIT_COMMANDS:
                break
            elif user_input in NEW_SESSION_COMMANDS:
                chat_history = []
                if hasattr(cache, "reset"):
                    cache.reset()
                else:
                    cache = None
            elif user_input == "!help":
                print_help()

            # Generate text
            else:
                # Inject the base prompt if the chat history is empty
                if len(chat_history) == 0:
                    user_input = base_prompt + user_input

                chat_history += [{"role": "user", "content": user_input},]
                if is_instruction_tuned:
                    user_input = tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)

                model_inputs = tokenizer(
                    user_input,
                    return_tensors="pt",
                    return_attention_mask=device == "mps",
                )
                input_ids = model_inputs.input_ids

                model_inputs = model_inputs.to(device)
                generation_kwargs.update(
                    {
                        "streamer": streamer,
                        "assistant_model": assistant_model,
                        "past_key_values": cache,
                    }
                )
                if args.max_new_tokens is not None:
                    generation_kwargs["max_new_tokens"] = args.max_new_tokens
                    input_ids_len = input_ids.shape[-1]
                    max_cache_len = args.max_new_tokens + input_ids_len
                    if cache is not None and cache.max_cache_len < max_cache_len:
                        # The pre-allocated cache is too small for this request: drop it and let
                        # `generate` build a new hybrid cache of the required length.
                        generation_kwargs.pop("past_key_values")
                        generation_kwargs["cache_implementation"] = "hybrid"
                else:
                    generation_kwargs["max_length"] = model.config.max_position_embeddings

                model_tokens = model.generate(**model_inputs, **generation_kwargs)
                model_tokens = model_tokens[0, input_ids.shape[1]:]
                model_output_text = tokenizer.decode(model_tokens, skip_special_tokens=True)
                chat_history += [{"role": "assistant", "content": model_output_text},]

                if is_instruction_tuned:
                    # Sanity check: the trailing EOS token was stripped and the tokenized chat ends in "<end_of_turn>\n"
                    tokenized_chat = tokenizer.apply_chat_template(
                        chat_history, tokenize=True, add_generation_prompt=False, return_tensors="pt"
                    ).tolist()[0]
                    assert tokenized_chat[0] == 2  # <bos>
                    assert tokenized_chat[-1] == 108  # "\n"
                    assert tokenized_chat[-2] == 107  # <end_of_turn>

            if has_starting_prompt or not is_instruction_tuned:
                break
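
For reference, here is a minimal standalone sketch of the tokenize → generate → decode pattern the interactive loop above is built on, using only the public transformers API. The checkpoint name, dtype, and `max_new_tokens` value are illustrative placeholders rather than the CLI's resolved defaults, and none of the preset, cache, or assisted-generation tuning from `main()` is applied:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

model_id = "google/gemma-2-2b-it"  # illustrative instruction-tuned checkpoint with a chat template
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

chat_history = [{"role": "user", "content": "The theory of relativity states"}]
prompt = tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Stream tokens to stdout as they are produced, skipping the prompt and special tokens.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
output_ids = model.generate(**inputs, streamer=streamer, max_new_tokens=64)

# Keep only the newly generated tokens (everything after the prompt) in the chat history.
reply = tokenizer.decode(output_ids[0, inputs.input_ids.shape[1]:], skip_special_tokens=True)
chat_history.append({"role": "assistant", "content": reply})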