phi3/olive/phi3.py

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import argparse
import json
import tempfile
import time
from pathlib import Path

import onnxruntime_genai as og

from olive.cli.base import save_output_model
from olive.common.utils import unescaped_str
from olive.workflows import run as olive_run

# flake8: noqa: T201


TARGETS = ["cpu", "cuda", "mobile", "web"]

TARGET_TO_EP = {
    "cpu": "CPUExecutionProvider",
    "mobile": "CPUExecutionProvider",
    "cuda": "CUDAExecutionProvider",
    "web": "JsExecutionProvider",
}

AML_MODEL_Path = {
    "type": "azureml_registry_model",
    "registry_name": "azureml",
    "name": "Phi-3-mini-4k-instruct",
    "version": "7",
}


def get_args(raw_args):
    parser = argparse.ArgumentParser(description="phi3 optimization")

    parser.add_argument(
        "--model_path",
        type=str,
        default="microsoft/Phi-3-mini-4k-instruct",
        help="Path to the model to optimize. Can be a hf model id or local path",
    )
    parser.add_argument(
        "--source",
        type=str,
        default="HF",
        choices=["HF", "AzureML"],
        help=(
            "Choose from HF(default), AzureML. If AzureML, model_path is overridden with the Phi-3-mini-4k-instruct"
            " from azureml model registry"
        ),
    )
    parser.add_argument(
        "--target",
        type=str,
        default=None,
        required=True,
        choices=TARGETS,
        help="Choose from cpu, cuda, mobile or web",
    )
    parser.add_argument(
        "--finetune_method",
        type=str,
        default=None,
        choices=["qlora", "lora"],
        help="Finetune method before onnxruntime optimization",
    )
    quant_group = parser.add_mutually_exclusive_group()
    quant_group.add_argument(
        "--quarot",
        action="store_true",
        help="Run QuaRot on a Hugging Face PyTorch model",
    )
    quant_group.add_argument(
        "--awq",
        action="store_true",
        help="Run AWQ on the base model or the finetuned model",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="int4",
        choices=["fp32", "fp16", "int4"],
        help=(
            "Choose from fp32 or int4(default) for cpu target; "
            "fp32 or fp16 or int4(default) for gpu target; int4(default) for mobile or web"
        ),
    )
    parser.add_argument(
        "--inference",
        action="store_true",
        help="Run inference with optimized model",
    )
    parser.add_argument(
        "--prompt",
        nargs="*",
        type=str,
        default=["Write a joke"],
        help="The prompt text fed into the model. Only used with --inference",
    )
    parser.add_argument(
        "--chat_template",
        type=unescaped_str,
        default=None,
        help=(
            "The chat template for the prompt. If not provided, will use default templates for base and finetuned"
            " models. Only used with --inference"
        ),
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=200,
        help="Max length for generation. Only used with --inference",
    )
    parser.add_argument("--output_dir", type=str, default="models/phi3", help="Output path for optimized model")
    parser.add_argument(
        "--cache_dir",
        type=str,
        default="cache",
        help="Path to cache directory",
    )

    return parser.parse_args(raw_args)


def main(raw_args=None):
    args = get_args(raw_args)

    if args.target in ("mobile", "web") and args.precision != "int4":
        raise ValueError("mobile or web only supports int4(default)")
    elif args.target == "cpu" and args.precision == "fp16":
        raise ValueError("Choose from fp32 or int4(default) for cpu target")

    if args.inference and args.target == "web":
        raise ValueError("Web model inference is not supported in this script")

    # Generate Olive configuration file for specific target
    print("\nGenerating Olive configuration file...")
    config_file = generate_config(args)
    print("Olive configuration file is generated...\n")

    # Generate optimized model for specific target
    print("Generating optimized model for", args.target, "...\n")
    output_path = Path(args.output_dir)
    with tempfile.TemporaryDirectory() as tempdir:
        with open(config_file) as f:
            run_config = json.load(f)
            if args.quarot:
                run_config["output_dir"] = args.output_dir
            else:
                run_config["output_dir"] = tempdir

        olive_run(run_config)

        if args.quarot:
            return

        save_output_model(run_config, output_path)

    if args.inference:
        if not args.chat_template:
            args.chat_template = (
                "### Question: {input} \n### Answer: "
                if args.finetune_method
                else "<|user|>\n{input}<|end|>\n<|assistant|>"
            )

        prompts = "Write a joke" if not args.prompt else "".join(args.prompt)
        prompts = f"{args.chat_template.format(input=prompts)}"

        max_length = 200 if not args.max_length else args.max_length
        genai_run(prompts, str(output_path / "model"), max_length)


def use_passes(template_json, *passes):
    use_data_configs = set()

    # remove unused passes
    for key in list(template_json["passes"].keys()):
        if key not in passes:
            del template_json["passes"][key]
            continue

        for param, value in template_json["passes"][key].items():
            if param.endswith("data_config"):
                use_data_configs.add(value)

    # remove unused data_configs
    if use_data_configs:
        template_json["data_configs"] = [
            data_config for data_config in template_json["data_configs"] if data_config["name"] in use_data_configs
        ]
    else:
        del template_json["data_configs"]

    template_json["pass_flows"] = [passes]
    return template_json


def generate_config(args):
    json_file_template = "phi3_template.json"
    with open(json_file_template) as f:
        template_json = json.load(f)

    config_prefix = "phi3_run_"

    if args.quarot:
        template_json = use_passes(template_json, "quarot")
        template_json["systems"]["local_system"]["accelerators"] = [
            {"device": "GPU", "execution_providers": ["CUDAExecutionProvider"]}
        ]
        new_json_file = f"{config_prefix}quarot.json"
        with open(new_json_file, "w") as f:
            json.dump(template_json, f, indent=4)

        return new_json_file

    # use aml instance of model
    if args.source == "AzureML":
        template_json["input_model"]["model_path"] = AML_MODEL_Path
    else:
        template_json["input_model"]["model_path"] = args.model_path

    # finetune
    passes_to_use = []
    if args.finetune_method:
        # adapters will be fine-tuned and merged into the model
        passes_to_use.extend([args.finetune_method, "merge_adapter_weights"])
    if args.awq:
        passes_to_use.append("awq")
        if args.precision != "int4":
            print("AWQ only supports int4 precision. Changing precision to int4")
            args.precision = "int4"
    passes_to_use.append("builder")

    target = str(args.target)
    if target == "web":
        # web doesn't have fp16 io
        passes_to_use.append("fp32_logits")

    # use the relevant passes
    template_json = use_passes(template_json, *passes_to_use)

    # set the accelerator
    device = "GPU" if target in ("cuda", "web") else "CPU"
    template_json["systems"]["local_system"]["accelerators"] = [
        {"device": device, "execution_providers": [TARGET_TO_EP[target.lower()]]}
    ]

    # set the precision
    template_json["passes"]["builder"]["precision"] = args.precision
    if target == "mobile":
        template_json["passes"]["builder"]["int4_accuracy_level"] = 4

    # set cache dir
    template_json["cache_dir"] = args.cache_dir

    new_json_file = f"{config_prefix}{target.lower()}_{args.precision}.json"
    with open(new_json_file, "w") as f:
        json.dump(template_json, f, indent=4)

    return new_json_file


def genai_run(prompt, model_path, max_length):
    print("\nModel inference starts...")

    print("Loading model...")
    app_started_timestamp = time.time()
    model = og.Model(model_path)
    model_loaded_timestamp = time.time()
    print("Model loaded in {:.2f} seconds".format(model_loaded_timestamp - app_started_timestamp))

    print("Creating tokenizer...")
    tokenizer = og.Tokenizer(model)
    tokenizer_stream = tokenizer.create_stream()
    input_tokens = tokenizer.encode(prompt)
    started_timestamp = time.time()

    print("Creating generator ...")
    params = og.GeneratorParams(model)
    # optimal search options for Phi3
    search_options = {
        "max_length": max_length,
        "top_k": 40,
        "top_p": 0.95,
        "temperature": 0.8,
        "repetition_penalty": 1.0,
    }
    params.set_search_options(**search_options)
    params.input_ids = input_tokens
    generator = og.Generator(model, params)
    print("Generator created")

    first = True
    first_token_timestamp = None
    new_tokens = []

    print("\n", prompt)

    try:
        while not generator.is_done():
            generator.compute_logits()
            generator.generate_next_token()
            if first:
                first_token_timestamp = time.time()
                first = False

            new_token = generator.get_next_tokens()[0]
            print(tokenizer_stream.decode(new_token), end="", flush=True)
            new_tokens.append(new_token)
    except KeyboardInterrupt:
        print(" --control+c pressed, aborting generation--")

    del generator

    run_time = time.time() - started_timestamp
    if first_token_timestamp is None:
        print("\n\nNo tokens generated")
    else:
        print(
            "\n\n"
            f"Prompt tokens: {len(input_tokens)}, New tokens: {len(new_tokens)},"
            f" Time to first: {(first_token_timestamp - started_timestamp):.2f}s,"
            f" New tokens per second: {len(new_tokens)/run_time:.2f} tps"
        )


if __name__ == "__main__":
    main()
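
For reference, below is a sketch of how this script might be invoked, based only on the arguments defined in get_args above. The working directory and prompt text are illustrative; the script expects phi3_template.json to be present in the current directory, since generate_config opens it by relative path.

# Optimize the base model for CPU with int4 quantization (the default precision)
# and run a quick inference check with a sample prompt.
python phi3.py --target cpu --inference --prompt "Write a joke"

# Fine-tune with QLoRA, quantize with AWQ, and build for CUDA;
# the optimized model is written under --output_dir (default: models/phi3).
python phi3.py --target cuda --finetune_method qlora --awq --output_dir models/phi3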