# local_gemma/cli.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
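# Set before importing torch/transformers: silence the tokenizers fork-parallelism warning and let
# TorchInductor cache compiled FX graphs across runs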
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
import argparse
import sys
import torch
from transformers import AutoTokenizer, TextStreamer, set_seed
from transformers.cache_utils import HybridCache
from transformers.utils import logging, is_flash_attn_2_available
from accelerate.utils import is_torch_version
from huggingface_hub import get_token, login
from local_gemma import LocalGemma2ForCausalLM
from .utils.benchmark import benchmark
from .utils.config import (
DTYPE_MODIFIER, infer_device, infer_dtype, get_prompt, get_generation_kwargs, infer_memory_requirements
)
torch.set_float32_matmul_precision("high")
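# Special commands recognized in the interactive session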
EXIT_COMMANDS = {"!exit", "quit", "quit()", "!exit()", "!quit", "!quit()"}
NEW_SESSION_COMMANDS = {"!new session", "!new session()", "!new chat", "!new chat()", "!new", "!reset"}
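# Shorthand model sizes mapped to the corresponding instruction-tuned checkpoints on the Hugging Face Hub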
MODEL_NAMES = {
"2b": "google/gemma-2-2b-it",
"9b": "google/gemma-2-9b-it",
"27b": "google/gemma-2-27b-it",
}
parser = argparse.ArgumentParser(description="Local Gemma")
# Prompt argument
parser.add_argument(
"prompt",
type=str,
nargs="*",
help=(
"Prompt to the model. For an interactive session, leave this field empty."
),
)
# Other control arguments
parser.add_argument(
"--model",
type=str,
default="9b",
help=(
"Size of Gemma 2 instruct model to be used in the application ('2b', '9b' or '27b') or, alternatively, a "
"Hugging Face repo. Defaults to '9b'."
),
)
parser.add_argument(
"--token",
type=str,
help="Authentication token for the model. Required to download the model into a local cache.",
)
parser.add_argument(
"--preset",
type=str,
choices=["auto", "exact", "speed", "memory", "memory_extreme"],
default="auto",
help=(
"Sets the optimization strategy for the local model deployment. Defaults to 'auto', which selects the best "
"strategy for your device. 'exact' maximises accuracy, 'speed' maximizes speed, 'memory' reduces "
"memory requirements through quantization, and 'memory_extreme' minimises memory requirements."
),
)
parser.add_argument(
"--mode",
type=str,
choices=["chat", "factual", "creative"],
default="chat",
help=(
"Sets the mode of operation of the model. 'chat' is optimized for general conversation, 'factual' is designed "
"to minimize hallucinations, and 'creative' is optimized for creative writing. Note that 'factual' and "
"'creative' prepend text to your prompt."
),
)
parser.add_argument(
"--max_new_tokens",
type=int,
help=(
"Maximum number of tokens to be used in each generation round. By default it relies on the model to emit an "
"EOS token, and generates up to the pretrained length."
),
)
parser.add_argument(
"--device",
type=str,
help="Forces a specific device to be used. By default uses cuda > mps > cpu, depending on availability.",
)
parser.add_argument(
"--dtype",
type=str,
help="The dtype in which computations are performed. Defaults to the dtype set by --preset",
)
parser.add_argument(
"--silent",
action="store_true",
help="Does NOT print any output except for the model outputs.",
)
# Debugging arguments
parser.add_argument(
"--seed",
type=int,
help="Seed for text generation. Optional, use for reproducibility.",
)
parser.add_argument(
"--benchmark",
action="store_true",
help="Runs a throughput benchmark on your device.",
)


def print_help(is_instruction_tuned: bool = True):
if is_instruction_tuned:
print("\nYou can now interact with the model through a conversation. A few tips:")
print("- Initialize the program with '--silent' to hide all non-model messages")
print("- Input '!exit' to leave the program")
print("- Input '!new session' to reset the conversation")
print("- Input '!help' to print this message again")
else:
print("\nYou can now pass a prompt to the base model to generate a single response.")
print("Tip: for multi-turn conversation, use an instruction tuned model, such as `google/gemma-2-9b-it`.")
print("")


def main():
args = parser.parse_args()
    # If data was piped via stdin, append it to the (possibly empty) positional prompt
    stdin_received = not sys.stdin.isatty()
    if stdin_received:
        input_data = sys.stdin.read()
        args.prompt = args.prompt + ["\n"] + [input_data]
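    # Resolve runtime settings: device, dtype, generation arguments and the mode-specific base prompt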
device = infer_device(args.device)
dtype = infer_dtype(device, args.dtype)
generation_kwargs = get_generation_kwargs(args.mode)
base_prompt = get_prompt(args.mode)
has_starting_prompt = len(args.prompt) > 0
model_name = MODEL_NAMES.get(args.model) or args.model
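    # Gemma 2 is a gated model: if no token is supplied or cached, walk the user through logging in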
if args.token is None:
if get_token() is None:
print("Using the gated Gemma model requires you to:")
print("1. Create an account on the Hugging Face Hub: https://huggingface.co/join")
print("2. Accept the Gemma-2 model terms of use: https://huggingface.co/google/gemma-2-9b")
print("3. Create an access token and paste it below: https://huggingface.co/settings/tokens")
login()
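    # When --preset is 'auto', infer the best preset from the memory available on the target device.
    # `spare_memory` (bytes) is only estimated in that case, so it defaults to 0 for manually chosen presets.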
if args.preset == "auto":
args.preset, spare_memory = infer_memory_requirements(
model_name, device, trust_remote_code=False, token=args.token
)
    # Use assisted generation on CUDA or MPS devices when the default 9b or 27b models are selected. Assisted
    # generation is not beneficial on most CPU setups, and it can't be combined with the speed preset (more
    # precisely, with `torch.compile`).
if (
args.model in ('9b', '27b')
and ("cuda" in device or device.isdigit() or "mps" in device)
and args.preset != "speed"
):
assistant_model_name = MODEL_NAMES["2b"]
if spare_memory / 1e9 > 5:
assistant_preset = "exact"
else:
assistant_preset = "memory"
else:
assistant_model_name = None
if not args.silent:
print("\nLoading model with the following characteristics:")
print("- Model name:", model_name)
print(f"- Assistant model name: {assistant_model_name if assistant_model_name is None else assistant_model_name + f' (loaded with `{assistant_preset}` preset)'}")
print("- Device:", device)
print("- Default data type:", str(dtype))
print("- Optimization preset:", args.preset)
print("- Generation arguments:", str(generation_kwargs))
print("- Base prompt:", repr(base_prompt) if len(base_prompt) > 0 else "None")
print("")
else:
logging.disable_progress_bar()
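    # A chat template on the tokenizer indicates an instruction-tuned checkpoint; base models get a single response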
tokenizer = AutoTokenizer.from_pretrained(model_name, token=args.token)
is_instruction_tuned = tokenizer.chat_template is not None
if args.preset == "speed" and device == "cuda" and (has_starting_prompt or not is_instruction_tuned):
# for single-turn responses, disable torch compile and enable fa2
# this way, we skip the lengthy compilation step and return the generation to the user as quickly as possible
torch_compile = False
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else None
else:
        # leave it to the preset to decide these settings
torch_compile = attn_implementation = None
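    # Load the main model; when `torch_compile` / `attn_implementation` are left as None, the preset decides them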
model = LocalGemma2ForCausalLM.from_pretrained(
model_name,
preset=args.preset,
torch_compile=torch_compile,
token=args.token,
torch_dtype=dtype,
device=device,
attn_implementation=attn_implementation,
)
# TODO(joao): this if shouldn't be needed, fix in transformers
model._supports_cache_class = True
if assistant_model_name is not None:
assistant_model = LocalGemma2ForCausalLM.from_pretrained(
assistant_model_name, preset=assistant_preset, token=args.token, torch_dtype=dtype, device=device)
else:
assistant_model = None
if args.benchmark:
benchmark(model=model, assistant_model=assistant_model, tokenizer=tokenizer)
else:
if args.seed is not None:
set_seed(args.seed)
if device == "mps" and args.max_new_tokens is None:
print(
"Setting max new tokens to 1024 for faster mps generation. To bypass this limit, set "
"`--max_new_tokens=2048`."
)
args.max_new_tokens = 1024
# Note: as of transformers 4.44, assisted generation does NOT work with any cache except dynamic cache
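        # A full-length HybridCache is pre-allocated only for open-ended generation without an assistant model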
if args.max_new_tokens is None and assistant_model is None:
cache = HybridCache(
model.config,
max_batch_size=1,
max_cache_len=model.config.max_position_embeddings,
device=model.device,
dtype=model.dtype,
)
model.generation_config.cache_implementation = None
else:
# when generating using max_new_tokens, update the cache on each generation step to limit memory
cache = None
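        # If the forward pass was compiled with torch.compile (speed preset), warm it up so the user-facing session is fast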
if hasattr(model.forward, "_torchdynamo_orig_callable"):
print(
"Compiling the model forward pass. This may take a few minutes, particularly the first time it is "
"run..."
)
if not is_torch_version(">=", "2.4"):
print(
"Install torch>=2.4.0 to cache the FX graphs across runs: https://pytorch.org/get-started/locally/"
)
chat_history = [{"role": "user", "content": "The theory of relativity states"}, ]
            # Two warm-up runs: the first triggers compilation and Triton autotuning, the second records the graph
            # and replays it. Subsequent generation calls then take the fast path.
for _ in range(2):
dummy_inputs = tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer(dummy_inputs, return_tensors="pt").to(model.device)
# prefill + generation
model_tokens = model.generate(**model_inputs, past_key_values=cache, max_new_tokens=16)
model_output_text = tokenizer.decode(model_tokens[0, model_inputs.input_ids.shape[1]:], skip_special_tokens=True)
chat_history += [{"role": "assistant", "content": model_output_text}, {"role": "user", "content": "Please repeat the above!"},]
            if cache is not None:
                cache.reset()
if not args.silent and not has_starting_prompt:
print_help(is_instruction_tuned=is_instruction_tuned)
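        # Stream generated tokens to stdout as they are produced, skipping the echoed prompt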
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
chat_history = []
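        # Main interactive loop: read input, handle special commands, otherwise generate a reply and extend the history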
while True:
# Get input to the model
if has_starting_prompt:
user_input = " ".join(args.prompt)
else:
user_input = input(">>> ")
# Handle special commands
if user_input in EXIT_COMMANDS:
break
elif user_input in NEW_SESSION_COMMANDS:
chat_history = []
if hasattr(cache, "reset"):
cache.reset()
else:
cache = None
elif user_input == "!help":
print_help()
# Generate text
else:
# Inject the base prompt if the chat history is empty
if len(chat_history) == 0:
user_input = base_prompt + user_input
chat_history += [{"role": "user", "content": user_input},]
if is_instruction_tuned:
user_input = tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)
                model_inputs = tokenizer(
                    user_input,
                    return_tensors="pt",
                    return_attention_mask=(device == "mps"),
                )
input_ids = model_inputs.input_ids
model_inputs = model_inputs.to(device)
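                # Route the streamer, the optional assistant model and the pre-allocated cache into this generation call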
generation_kwargs.update(
{
"streamer": streamer,
"assistant_model": assistant_model,
"past_key_values": cache,
}
)
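                # Bound generation by the user-provided --max_new_tokens, otherwise by the model's maximum context length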
if args.max_new_tokens is not None:
generation_kwargs["max_new_tokens"] = args.max_new_tokens
input_ids_len = input_ids.shape[-1]
max_cache_len = args.max_new_tokens + input_ids_len
if cache is not None and cache.max_cache_len < max_cache_len:
                        # the pre-allocated cache is too small for this request: drop it and let `generate` build a hybrid cache
generation_kwargs.pop("past_key_values")
generation_kwargs["cache_implementation"] = "hybrid"
else:
generation_kwargs["max_length"] = model.config.max_position_embeddings
model_tokens = model.generate(**model_inputs, **generation_kwargs)
model_tokens = model_tokens[0, input_ids.shape[1]:]
model_output_text = tokenizer.decode(model_tokens, skip_special_tokens=True)
chat_history += [{"role": "assistant", "content": model_output_text},]
if is_instruction_tuned:
                    # Sanity check: the templated chat starts with <bos> and ends in "<end_of_turn>\n", i.e. special
                    # tokens were stripped from the decoded reply
tokenized_chat = tokenizer.apply_chat_template(
chat_history, tokenize=True, add_generation_prompt=False, return_tensors="pt"
).tolist()[0]
                    assert tokenized_chat[0] == 2  # <bos>
                    assert tokenized_chat[-1] == 108  # "\n"
                    assert tokenized_chat[-2] == 107  # <end_of_turn>
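            # Single-prompt invocations and base (non-chat) models are one-shot: stop after the first reply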
if has_starting_prompt or not is_instruction_tuned:
break


if __name__ == "__main__":
main()