google/generativeai/notebook/cmd_line_parser.py (365 lines of code) (raw):

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parses an LLM command line."""
from __future__ import annotations

import argparse
import shlex
import sys
from typing import AbstractSet, Any, Callable, MutableMapping, Sequence

from google.generativeai.notebook import argument_parser
from google.generativeai.notebook import flag_def
from google.generativeai.notebook import input_utils
from google.generativeai.notebook import model_registry
from google.generativeai.notebook import output_utils
from google.generativeai.notebook import parsed_args_lib
from google.generativeai.notebook import post_process_utils
from google.generativeai.notebook import py_utils
from google.generativeai.notebook import sheets_utils
from google.generativeai.notebook.lib import llm_function
from google.generativeai.notebook.lib import llmfn_inputs_source
from google.generativeai.notebook.lib import llmfn_outputs
from google.generativeai.notebook.lib import model as model_lib

# Inclusive bounds accepted by the --candidate_count flag.
_MIN_CANDIDATE_COUNT = 1
_MAX_CANDIDATE_COUNT = 8


def _validate_input_source_against_placeholders(
    source: llmfn_inputs_source.LLMFnInputsSource,
    placeholders: AbstractSet[str],
) -> None:
    """Checks that every normalized input contains every placeholder.

    Args:
      source: The inputs source; each entry yielded by its
        `to_normalized_inputs()` method is checked.
      placeholders: The keywords that must be present in each entry.

    Raises:
      ValueError: If any entry is missing one of the placeholders.
    """
    for inputs in source.to_normalized_inputs():
        for keyword in placeholders:
            if keyword not in inputs:
                raise ValueError('Placeholder "{}" not found in input'.format(keyword))


def _get_resolve_input_from_py_var_fn(
    placeholders: AbstractSet[str] | None,
) -> Callable[[str], llmfn_inputs_source.LLMFnInputsSource]:
    """Returns a function that resolves a value passed into --inputs.

    The returned function looks up an inputs source from the named Python
    variable and, if `placeholders` is non-empty, validates the source against
    them.

    Args:
      placeholders: Placeholders the resolved source must provide, or None (or
        empty) to skip validation.
    """

    def _fn(var_name: str) -> llmfn_inputs_source.LLMFnInputsSource:
        source = input_utils.get_inputs_source_from_py_var(var_name)
        if placeholders:
            _validate_input_source_against_placeholders(source, placeholders)
        return source

    return _fn


def _resolve_compare_fn_var(
    name: str,
) -> tuple[str, parsed_args_lib.TextResultCompareFn]:
    """Resolves a value passed into --compare_fn.

    Args:
      name: The name of a Python variable expected to hold a callable.

    Returns:
      A tuple of (variable name, callable stored in the variable).

    Raises:
      ValueError: If the variable does not hold a callable.
    """
    fn = py_utils.get_py_var(name)
    if not callable(fn):
        raise ValueError('Variable "{}" does not contain a Callable object'.format(name))
    return name, fn


def _resolve_ground_truth_var(name: str) -> Sequence[str]:
    """Resolves a value passed into --ground_truth.

    Args:
      name: The name of a Python variable expected to hold a Sequence of
        strings.

    Returns:
      The value stored in the variable.

    Raises:
      ValueError: If the value is not a Sequence, is itself a `str` or `bytes`,
        or contains a non-string element.
    """
    value = py_utils.get_py_var(name)

    # "str" and "bytes" are also Sequences but we want an actual Sequence of
    # strings, like a list.
    if not isinstance(value, Sequence) or isinstance(value, (str, bytes)):
        raise ValueError('Variable "{}" does not contain a Sequence of strings'.format(name))
    for x in value:
        if not isinstance(x, str):
            raise ValueError('Variable "{}" does not contain a Sequence of strings'.format(name))
    return value


def _get_resolve_sheets_inputs_fn(
    placeholders: AbstractSet[str] | None,
) -> Callable[[str], llmfn_inputs_source.LLMFnInputsSource]:
    """Returns a function that resolves a value passed into --sheets_input_names.

    The returned function builds a `SheetsInputs` source from the given string
    and, if `placeholders` is non-empty, validates the source against them.

    Args:
      placeholders: Placeholders the resolved source must provide, or None (or
        empty) to skip validation.
    """

    def _fn(value: str) -> llmfn_inputs_source.LLMFnInputsSource:
        sheets_id = sheets_utils.get_sheets_id_from_str(value)
        source = sheets_utils.SheetsInputs(sheets_id)
        if placeholders:
            _validate_input_source_against_placeholders(source, placeholders)
        return source

    return _fn


def _resolve_sheets_outputs(value: str) -> llmfn_outputs.LLMFnOutputsSink:
    """Resolves a value passed into --sheets_output_names into a SheetsOutputs."""
    sheets_id = sheets_utils.get_sheets_id_from_str(value)
    return sheets_utils.SheetsOutputs(sheets_id)


def _add_model_flags(
    parser: argparse.ArgumentParser,
) -> None:
    """Adds flags that are related to model selection and config."""
    flag_def.EnumFlagDef(
        name="model_type",
        short_name="mt",
        enum_type=model_registry.ModelName,
        default_value=model_registry.ModelRegistry.DEFAULT_MODEL,
        help_msg="The type of model to use.",
    ).add_argument_to_parser(parser)

    def _check_is_greater_than_or_equal_to_zero(x: float) -> float:
        if x < 0:
            raise ValueError("Value should be greater than or equal to zero, got {}".format(x))
        return x

    flag_def.SingleValueFlagDef(
        name="temperature",
        short_name="t",
        parse_type=float,
        # Use None for default value to indicate that this will use the default
        # value in Text service.
        default_value=None,
        parse_to_dest_type_fn=_check_is_greater_than_or_equal_to_zero,
        # The validator above accepts zero, so the help text says "non-negative"
        # rather than "positive".
        help_msg=(
            "Controls the randomness of the output. Must be non-negative. Typical"
            " values are in the range: [0.0, 1.0]. Higher values produce a more"
            " random and varied response. A temperature of zero will be"
            " deterministic."
        ),
    ).add_argument_to_parser(parser)

    flag_def.SingleValueFlagDef(
        name="model",
        short_name="m",
        default_value=None,
        help_msg=(
            "The name of the model to use. If not provided, a default model will"
            " be used."
        ),
    ).add_argument_to_parser(parser)

    def _check_candidate_count_range(x: Any) -> int:
        if x < _MIN_CANDIDATE_COUNT or x > _MAX_CANDIDATE_COUNT:
            raise ValueError(
                "Value should be in the range [{}, {}], got {}".format(
                    _MIN_CANDIDATE_COUNT, _MAX_CANDIDATE_COUNT, x
                )
            )
        return int(x)

    flag_def.SingleValueFlagDef(
        name="candidate_count",
        short_name="cc",
        parse_type=int,
        # Use None for default value to indicate that this will use the default
        # value in Text service.
        default_value=None,
        parse_to_dest_type_fn=_check_candidate_count_range,
        help_msg="The number of candidates to produce.",
    ).add_argument_to_parser(parser)

    flag_def.BooleanFlagDef(
        name="unique",
        help_msg="Whether to dedupe candidates returned by the model.",
    ).add_argument_to_parser(parser)


def _add_input_flags(
    parser: argparse.ArgumentParser,
    placeholders: AbstractSet[str] | None,
) -> None:
    """Adds flags to read inputs from a Python variable or Sheets."""
    flag_def.MultiValuesFlagDef(
        name="inputs",
        short_name="i",
        dest_type=llmfn_inputs_source.LLMFnInputsSource,
        parse_to_dest_type_fn=_get_resolve_input_from_py_var_fn(placeholders),
        help_msg=(
            "Optional names of Python variables containing inputs to use to"
            " instantiate a prompt. The variable must be either: a dictionary"
            " {'key1': ['val1', 'val2'] ...}, or an instance of LLMFnInputsSource"
            " such as SheetsInput."
        ),
    ).add_argument_to_parser(parser)

    flag_def.MultiValuesFlagDef(
        name="sheets_input_names",
        short_name="si",
        dest_type=llmfn_inputs_source.LLMFnInputsSource,
        parse_to_dest_type_fn=_get_resolve_sheets_inputs_fn(placeholders),
        help_msg=(
            "Optional names of Google Sheets to read inputs from. This is"
            " equivalent to using --inputs with the names of variables that are"
            " instances of SheetsInputs, just more convenient to use."
        ),
    ).add_argument_to_parser(parser)


def _add_output_flags(
    parser: argparse.ArgumentParser,
) -> None:
    """Adds flags to write outputs to a Python variable or Sheets."""
    flag_def.MultiValuesFlagDef(
        name="outputs",
        short_name="o",
        dest_type=llmfn_outputs.LLMFnOutputsSink,
        parse_to_dest_type_fn=output_utils.get_outputs_sink_from_py_var,
        help_msg=(
            "Optional names of Python variables to output to. If the Python"
            " variable has not already been defined, it will be created. If the"
            " variable is defined and is an instance of LLMFnOutputsSink, the"
            " outputs will be written through the sink's write_outputs() method."
        ),
    ).add_argument_to_parser(parser)

    flag_def.MultiValuesFlagDef(
        name="sheets_output_names",
        short_name="so",
        dest_type=llmfn_outputs.LLMFnOutputsSink,
        parse_to_dest_type_fn=_resolve_sheets_outputs,
        help_msg=(
            "Optional names of Google Sheets to write outputs to. This is"
            " equivalent to using --outputs with the names of variables that are"
            " instances of SheetsOutputs, just more convenient to use."
        ),
    ).add_argument_to_parser(parser)


def _add_compare_flags(
    parser: argparse.ArgumentParser,
) -> None:
    """Adds the --compare_fn flag used to supply comparison functions."""
    flag_def.MultiValuesFlagDef(
        name="compare_fn",
        dest_type=tuple,
        parse_to_dest_type_fn=_resolve_compare_fn_var,
        help_msg=(
            "An optional function that takes two inputs: (lhs_result, rhs_result)"
            " which are the results of the left- and right-hand side functions. "
            "Multiple comparison functions can be provided."
        ),
    ).add_argument_to_parser(parser)


def _add_eval_flags(
    parser: argparse.ArgumentParser,
) -> None:
    """Adds the required --ground_truth flag."""
    flag_def.SingleValueFlagDef(
        name="ground_truth",
        required=True,
        dest_type=Sequence,
        parse_to_dest_type_fn=_resolve_ground_truth_var,
        help_msg=(
            "A variable containing a Sequence of strings representing the ground"
            " truth that the output of this cell will be compared against. It"
            " should have the same number of entries as inputs."
        ),
    ).add_argument_to_parser(parser)


def _create_run_parser(
    parser: argparse.ArgumentParser,
    placeholders: AbstractSet[str] | None,
) -> None:
    """Adds flags for the `run` command.

    `run` sends one or more prompts to a model.

    Args:
      parser: The parser to which flags will be added.
      placeholders: Placeholders from prompts in the cell contents.
    """
    _add_model_flags(parser)
    _add_input_flags(parser, placeholders)
    _add_output_flags(parser)


def _create_compile_parser(
    parser: argparse.ArgumentParser,
) -> None:
    """Adds flags for the compile command.

    `compile` "compiles" a prompt and model call into a callable function.

    Args:
      parser: The parser to which flags will be added.
    """

    # Add a positional argument for "compile_save_name".
    def _compile_save_name_fn(var_name: str) -> str:
        try:
            py_utils.validate_var_name(var_name)
        except ValueError as e:
            # Re-raise as ArgumentError to preserve the original error message.
            raise argparse.ArgumentError(None, "{}".format(e)) from e
        return var_name

    save_name_help = "The name of a Python variable to save the compiled function to."
    parser.add_argument("compile_save_name", help=save_name_help, type=_compile_save_name_fn)

    _add_model_flags(parser)


def _create_compare_parser(
    parser: argparse.ArgumentParser,
    placeholders: AbstractSet[str] | None,
) -> None:
    """Adds flags for the compare command.

    Args:
      parser: The parser to which flags will be added.
      placeholders: Placeholders from prompts in the compiled functions.
    """

    # Add positional arguments.
    def _resolve_llm_function_fn(
        var_name: str,
    ) -> tuple[str, llm_function.LLMFunction]:
        try:
            py_utils.validate_var_name(var_name)
        except ValueError as e:
            # Re-raise as ArgumentError to preserve the original error message.
            raise argparse.ArgumentError(None, "{}".format(e)) from e

        fn = py_utils.get_py_var(var_name)
        if not isinstance(fn, llm_function.LLMFunction):
            raise argparse.ArgumentError(
                None,
                '{} is not a function created with the "compile" command'.format(var_name),
            )
        return var_name, fn

    name_help = (
        "The name of a Python variable containing a function previously created"
        ' with the "compile" command.'
    )
    parser.add_argument("lhs_name_and_fn", help=name_help, type=_resolve_llm_function_fn)
    parser.add_argument("rhs_name_and_fn", help=name_help, type=_resolve_llm_function_fn)

    _add_input_flags(parser, placeholders)
    _add_output_flags(parser)
    _add_compare_flags(parser)


def _create_eval_parser(
    parser: argparse.ArgumentParser,
    placeholders: AbstractSet[str] | None,
) -> None:
    """Adds flags for the eval command.

    Args:
      parser: The parser to which flags will be added.
      placeholders: Placeholders from prompts in the cell contents.
    """
    _add_model_flags(parser)
    _add_input_flags(parser, placeholders)
    _add_output_flags(parser)
    _add_compare_flags(parser)
    _add_eval_flags(parser)


def _create_parser(
    placeholders: AbstractSet[str] | None,
) -> argparse.ArgumentParser:
    """Create the full parser."""
    system_name = "llm"
    description = "A system for interacting with LLMs."
    epilog = ""

    # Commands
    parser = argument_parser.ArgumentParser(
        prog=system_name,
        description=description,
        epilog=epilog,
    )
    subparsers = parser.add_subparsers(dest="cmd")
    _create_run_parser(
        subparsers.add_parser(parsed_args_lib.CommandName.RUN_CMD.value),
        placeholders,
    )
    _create_compile_parser(subparsers.add_parser(parsed_args_lib.CommandName.COMPILE_CMD.value))
    _create_compare_parser(
        subparsers.add_parser(parsed_args_lib.CommandName.COMPARE_CMD.value),
        placeholders,
    )
    _create_eval_parser(
        subparsers.add_parser(parsed_args_lib.CommandName.EVAL_CMD.value),
        placeholders,
    )
    return parser


def _validate_parsed_args(parsed_args: parsed_args_lib.ParsedArgs) -> None:
    """Prints a notice if --unique is set without a candidate count above one.

    This only prints; it does not raise or modify `parsed_args`.
    """
    # If candidate_count is not set (i.e. is None), assuming the default value
    # is 1.
    if parsed_args.unique and (
        parsed_args.model_args.candidate_count is None
        or parsed_args.model_args.candidate_count == 1
    ):
        print(
            '"--unique" works across candidates only: it should be used with'
            " --candidate_count set to a value greater-than one."
        )


class CmdLineParser:
    """Implementation of Magics command line parser."""

    # Commands
    DEFAULT_CMD = parsed_args_lib.CommandName.RUN_CMD

    # Post-processing operator.
    PIPE_OP = "|"

    @classmethod
    def _split_post_processing_tokens(
        cls,
        tokens: Sequence[str],
    ) -> tuple[Sequence[str], parsed_args_lib.PostProcessingTokens]:
        """Splits inputs into the command and post processing tokens.

        The command is represented as a sequence of tokens.
        See comments on the PostProcessingTokens type alias.

        E.g. Given: "run --temperature 0.5 | add_score | to_lower_case"
        The command will be: ["run", "--temperature", "0.5"].
        The post processing tokens will be: [["add_score"], ["to_lower_case"]]

        Args:
          tokens: The command line tokens.

        Returns:
          A tuple of (command line, post processing tokens).
        """
        # Start with one (possibly empty) segment for the command; every PIPE_OP
        # closes the current segment and opens a new one, so consecutive or
        # trailing pipes produce empty segments.
        segments: list[list[str]] = [[]]
        for token in tokens:
            if token == cls.PIPE_OP:
                segments.append([])
            else:
                segments[-1].append(token)

        return segments[0], segments[1:]

    @classmethod
    def _tokenize_line(
        cls, line: str
    ) -> tuple[Sequence[str], parsed_args_lib.PostProcessingTokens]:
        """Parses `line` and returns command line and post processing tokens."""
        # Check to make sure there is a command at the start. If not, add the
        # default command to the list of tokens.
        tokens = shlex.split(line)
        if not tokens:
            tokens = [CmdLineParser.DEFAULT_CMD.value]
        first_token = tokens[0]
        # Add default command if the first token is not the help token.
        # Slice rather than index: shlex can yield an empty first token (e.g. for
        # the line '""'), and indexing it would raise IndexError.
        if not first_token[:1].isalpha() and first_token not in ["-h", "--help"]:
            tokens = [CmdLineParser.DEFAULT_CMD.value] + tokens
        # Split line into tokens and post-processing
        return CmdLineParser._split_post_processing_tokens(tokens)

    @classmethod
    def _get_model_args(
        cls, parsed_results: MutableMapping[str, Any]
    ) -> tuple[MutableMapping[str, Any], model_lib.ModelArguments]:
        """Extracts fields for model args from `parsed_results`.

        Keys specific to model arguments will be removed from `parsed_results`.

        Args:
          parsed_results: A dictionary of parsed arguments (from ArgumentParser).
            It will be modified in place.

        Returns:
          A tuple of (updated parsed_results, model arguments).
        """
        model = parsed_results.pop("model", None)
        temperature = parsed_results.pop("temperature", None)
        candidate_count = parsed_results.pop("candidate_count", None)

        model_args = model_lib.ModelArguments(
            model=model,
            temperature=temperature,
            candidate_count=candidate_count,
        )
        return parsed_results, model_args

    def parse_line(
        self,
        line: str,
        placeholders: AbstractSet[str] | None = None,
    ) -> tuple[parsed_args_lib.ParsedArgs, parsed_args_lib.PostProcessingTokens]:
        """Parses the commandline and returns ParsedArgs and post-processing tokens.

        Args:
          line: The line to parse (usually contents from cell Magics).
          placeholders: Placeholders from prompts in the cell contents.

        Returns:
          A tuple of (parsed_args, post_processing_tokens).
        """
        tokens, post_processing_tokens = CmdLineParser._tokenize_line(line)

        parsed_args = self._get_parsed_args_from_cmd_line_tokens(
            tokens=tokens, placeholders=placeholders
        )

        # Special-case for "compare" command: because the prompts are compiled into
        # the left- and right-hand side functions rather than in the cell body, we
        # cannot examine the cell body to get the placeholders.
        #
        # Instead we parse the command line twice: once to get the left- and right-
        # functions, then we query the functions for their placeholders, then
        # parse the commandline again to validate the inputs.
        if parsed_args.cmd == parsed_args_lib.CommandName.COMPARE_CMD:
            assert parsed_args.lhs_name_and_fn is not None
            assert parsed_args.rhs_name_and_fn is not None
            _, lhs_fn = parsed_args.lhs_name_and_fn
            _, rhs_fn = parsed_args.rhs_name_and_fn
            parsed_args = self._get_parsed_args_from_cmd_line_tokens(
                tokens=tokens,
                placeholders=frozenset(lhs_fn.get_placeholders()).union(rhs_fn.get_placeholders()),
            )

        _validate_parsed_args(parsed_args)

        for expr in post_processing_tokens:
            post_process_utils.validate_one_post_processing_expression(expr)

        return parsed_args, post_processing_tokens

    def _get_parsed_args_from_cmd_line_tokens(
        self,
        tokens: Sequence[str],
        placeholders: AbstractSet[str] | None,
    ) -> parsed_args_lib.ParsedArgs:
        """Returns ParsedArgs from a tokenized command line."""
        # Create a new parser to avoid reusing the temporary argparse.Namespace
        # object.
        results = _create_parser(placeholders).parse_args(tokens)

        results_dict = vars(results)
        results_dict["cmd"] = parsed_args_lib.CommandName(results_dict["cmd"])

        results_dict, model_args = CmdLineParser._get_model_args(results_dict)
        results_dict["model_args"] = model_args

        return parsed_args_lib.ParsedArgs(**results_dict)