google/generativeai/notebook/cmd_line_parser.py (365 lines of code) (raw):
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parses an LLM command line."""
from __future__ import annotations
import argparse
import shlex
import sys
from typing import AbstractSet, Any, Callable, MutableMapping, Sequence
from google.generativeai.notebook import argument_parser
from google.generativeai.notebook import flag_def
from google.generativeai.notebook import input_utils
from google.generativeai.notebook import model_registry
from google.generativeai.notebook import output_utils
from google.generativeai.notebook import parsed_args_lib
from google.generativeai.notebook import post_process_utils
from google.generativeai.notebook import py_utils
from google.generativeai.notebook import sheets_utils
from google.generativeai.notebook.lib import llm_function
from google.generativeai.notebook.lib import llmfn_inputs_source
from google.generativeai.notebook.lib import llmfn_outputs
from google.generativeai.notebook.lib import model as model_lib
# Inclusive lower and upper bounds accepted for the "candidate_count" flag.
# They are enforced by _check_candidate_count_range() in _add_model_flags(),
# which raises ValueError for values outside [_MIN, _MAX].
_MIN_CANDIDATE_COUNT = 1
_MAX_CANDIDATE_COUNT = 8
def _validate_input_source_against_placeholders(
source: llmfn_inputs_source.LLMFnInputsSource,
placeholders: AbstractSet[str],
) -> None:
for inputs in source.to_normalized_inputs():
for keyword in placeholders:
if keyword not in inputs:
raise ValueError('Placeholder "{}" not found in input'.format(keyword))
def _get_resolve_input_from_py_var_fn(
    placeholders: AbstractSet[str] | None,
) -> Callable[[str], llmfn_inputs_source.LLMFnInputsSource]:
    """Builds the resolver that turns a variable name into an inputs source.

    Args:
      placeholders: Placeholders the resolved source must provide. When None
        or empty, the resolved source is not validated.

    Returns:
      A function that takes the name of a Python variable and returns the
      inputs source obtained from it.
    """

    def _fn(var_name: str) -> llmfn_inputs_source.LLMFnInputsSource:
        inputs_source = input_utils.get_inputs_source_from_py_var(var_name)
        if not placeholders:
            # Nothing to validate against.
            return inputs_source
        _validate_input_source_against_placeholders(inputs_source, placeholders)
        return inputs_source

    return _fn
def _resolve_compare_fn_var(
    name: str,
) -> tuple[str, parsed_args_lib.TextResultCompareFn]:
    """Resolves a value passed into --compare_fn.

    Args:
      name: The name of the Python variable holding the comparison function.

    Returns:
      A tuple of (variable name, the callable stored in that variable).

    Raises:
      ValueError: If the variable does not hold a Callable.
    """
    candidate = py_utils.get_py_var(name)
    if isinstance(candidate, Callable):
        return name, candidate
    raise ValueError('Variable "{}" does not contain a Callable object'.format(name))
def _resolve_ground_truth_var(name: str) -> Sequence[str]:
    """Resolves a value passed into --ground_truth.

    Args:
      name: The name of the Python variable holding the ground truth.

    Returns:
      The value of the variable, unchanged, once it has been checked.

    Raises:
      ValueError: If the variable is not a Sequence, is itself a `str` or
        `bytes`, or contains an element that is not a `str`.
    """
    value = py_utils.get_py_var(name)
    error_message = 'Variable "{}" does not contain a Sequence of strings'.format(name)

    # "str" and "bytes" are Sequences too, but a real container of strings
    # (such as a list) is required here.
    is_string_like = isinstance(value, (str, bytes))
    if is_string_like or not isinstance(value, Sequence):
        raise ValueError(error_message)

    if not all(isinstance(item, str) for item in value):
        raise ValueError(error_message)
    return value
def _get_resolve_sheets_inputs_fn(
    placeholders: AbstractSet[str] | None,
) -> Callable[[str], llmfn_inputs_source.LLMFnInputsSource]:
    """Builds the resolver that turns a Sheets reference into an inputs source.

    Args:
      placeholders: Placeholders the resolved source must provide. When None
        or empty, the resolved source is not validated.

    Returns:
      A function that takes a string identifying a Google Sheets document and
      returns a `SheetsInputs` source for it.
    """

    def _fn(value: str) -> llmfn_inputs_source.LLMFnInputsSource:
        sheets_source = sheets_utils.SheetsInputs(sheets_utils.get_sheets_id_from_str(value))
        if not placeholders:
            # Nothing to validate against.
            return sheets_source
        _validate_input_source_against_placeholders(sheets_source, placeholders)
        return sheets_source

    return _fn
def _resolve_sheets_outputs(value: str) -> llmfn_outputs.LLMFnOutputsSink:
    """Resolves a Sheets reference into a `SheetsOutputs` sink.

    Args:
      value: A string identifying a Google Sheets document.

    Returns:
      A `SheetsOutputs` instance built from the Sheets ID derived from `value`.
    """
    return sheets_utils.SheetsOutputs(sheets_utils.get_sheets_id_from_str(value))
def _add_model_flags(
    parser: argparse.ArgumentParser,
) -> None:
    """Adds flags that are related to model selection and config.

    The flags registered are "model_type", "temperature", "model",
    "candidate_count" and "unique".

    Args:
      parser: The parser to which flags will be added.
    """
    flag_def.EnumFlagDef(
        name="model_type",
        short_name="mt",
        enum_type=model_registry.ModelName,
        default_value=model_registry.ModelRegistry.DEFAULT_MODEL,
        help_msg="The type of model to use.",
    ).add_argument_to_parser(parser)

    def _check_is_greater_than_or_equal_to_zero(x: float) -> float:
        """Returns `x` unchanged, or raises ValueError if it is negative."""
        if x < 0:
            raise ValueError("Value should be greater than or equal to zero, got {}".format(x))
        return x

    flag_def.SingleValueFlagDef(
        name="temperature",
        short_name="t",
        parse_type=float,
        # Use None for default value to indicate that this will use the default
        # value in Text service.
        default_value=None,
        parse_to_dest_type_fn=_check_is_greater_than_or_equal_to_zero,
        # The validator above accepts zero, so the help text says
        # "non-negative" rather than "positive".
        help_msg=(
            "Controls the randomness of the output. Must be non-negative. Typical"
            " values are in the range: [0.0, 1.0]. Higher values produce a more"
            " random and varied response. A temperature of zero will be"
            " deterministic."
        ),
    ).add_argument_to_parser(parser)

    flag_def.SingleValueFlagDef(
        name="model",
        short_name="m",
        default_value=None,
        help_msg="The name of the model to use. If not provided, a default model will be used.",
    ).add_argument_to_parser(parser)

    def _check_candidate_count_range(x: Any) -> int:
        """Returns `int(x)`, or raises ValueError if `x` is out of range."""
        if x < _MIN_CANDIDATE_COUNT or x > _MAX_CANDIDATE_COUNT:
            raise ValueError(
                "Value should be in the range [{}, {}], got {}".format(
                    _MIN_CANDIDATE_COUNT, _MAX_CANDIDATE_COUNT, x
                )
            )
        return int(x)

    flag_def.SingleValueFlagDef(
        name="candidate_count",
        short_name="cc",
        parse_type=int,
        # Use None for default value to indicate that this will use the default
        # value in Text service.
        default_value=None,
        parse_to_dest_type_fn=_check_candidate_count_range,
        help_msg="The number of candidates to produce.",
    ).add_argument_to_parser(parser)

    flag_def.BooleanFlagDef(
        name="unique",
        help_msg="Whether to dedupe candidates returned by the model.",
    ).add_argument_to_parser(parser)
def _add_input_flags(
    parser: argparse.ArgumentParser,
    placeholders: AbstractSet[str] | None,
) -> None:
    """Registers the flags that read inputs from a Python variable or Sheets.

    Args:
      parser: The parser to which flags will be added.
      placeholders: Placeholders handed to the resolver functions, which
        validate resolved input sources against them.
    """
    inputs_source_type = llmfn_inputs_source.LLMFnInputsSource

    py_var_inputs_flag = flag_def.MultiValuesFlagDef(
        name="inputs",
        short_name="i",
        dest_type=inputs_source_type,
        parse_to_dest_type_fn=_get_resolve_input_from_py_var_fn(placeholders),
        help_msg=(
            "Optional names of Python variables containing inputs to use to"
            " instantiate a prompt. The variable must be either: a dictionary"
            " {'key1': ['val1', 'val2'] ...}, or an instance of LLMFnInputsSource"
            " such as SheetsInput."
        ),
    )
    py_var_inputs_flag.add_argument_to_parser(parser)

    sheets_inputs_flag = flag_def.MultiValuesFlagDef(
        name="sheets_input_names",
        short_name="si",
        dest_type=inputs_source_type,
        parse_to_dest_type_fn=_get_resolve_sheets_inputs_fn(placeholders),
        help_msg=(
            "Optional names of Google Sheets to read inputs from. This is"
            " equivalent to using --inputs with the names of variables that are"
            " instances of SheetsInputs, just more convenient to use."
        ),
    )
    sheets_inputs_flag.add_argument_to_parser(parser)
def _add_output_flags(
    parser: argparse.ArgumentParser,
) -> None:
    """Adds flags to write outputs to a Python variable or Sheets.

    Args:
      parser: The parser to which flags will be added.
    """
    flag_def.MultiValuesFlagDef(
        name="outputs",
        short_name="o",
        dest_type=llmfn_outputs.LLMFnOutputsSink,
        parse_to_dest_type_fn=output_utils.get_outputs_sink_from_py_var,
        help_msg=(
            "Optional names of Python variables to output to. If the Python"
            " variable has not already been defined, it will be created. If the"
            " variable is defined and is an instance of LLMFnOutputsSink, the"
            " outputs will be written through the sink's write_outputs() method."
        ),
    ).add_argument_to_parser(parser)

    flag_def.MultiValuesFlagDef(
        name="sheets_output_names",
        short_name="so",
        dest_type=llmfn_outputs.LLMFnOutputsSink,
        parse_to_dest_type_fn=_resolve_sheets_outputs,
        # This flag resolves to output sinks, so the help text says "outputs"
        # (it previously said "inputs").
        help_msg=(
            "Optional names of Google Sheets to write outputs to. This is"
            " equivalent to using --outputs with the names of variables that are"
            " instances of SheetsOutputs, just more convenient to use."
        ),
    ).add_argument_to_parser(parser)
def _add_compare_flags(
    parser: argparse.ArgumentParser,
) -> None:
    """Registers the "compare_fn" flag used to supply comparison functions.

    Args:
      parser: The parser to which flags will be added.
    """
    compare_fn_flag = flag_def.MultiValuesFlagDef(
        name="compare_fn",
        dest_type=tuple,
        parse_to_dest_type_fn=_resolve_compare_fn_var,
        help_msg=(
            "An optional function that takes two inputs: (lhs_result, rhs_result)"
            " which are the results of the left- and right-hand side functions. "
            "Multiple comparison functions can be provided."
        ),
    )
    compare_fn_flag.add_argument_to_parser(parser)
def _add_eval_flags(
    parser: argparse.ArgumentParser,
) -> None:
    """Registers the required "ground_truth" flag.

    Args:
      parser: The parser to which flags will be added.
    """
    ground_truth_flag = flag_def.SingleValueFlagDef(
        name="ground_truth",
        required=True,
        dest_type=Sequence,
        parse_to_dest_type_fn=_resolve_ground_truth_var,
        help_msg=(
            "A variable containing a Sequence of strings representing the ground"
            " truth that the output of this cell will be compared against. It"
            " should have the same number of entries as inputs."
        ),
    )
    ground_truth_flag.add_argument_to_parser(parser)
def _create_run_parser(
    parser: argparse.ArgumentParser,
    placeholders: AbstractSet[str] | None,
) -> None:
    """Populates `parser` with the flags of the `run` command.

    `run` sends one or more prompts to a model. It accepts the model flags,
    the input flags and the output flags.

    Args:
      parser: The parser to which flags will be added.
      placeholders: Placeholders from prompts in the cell contents.
    """
    _add_model_flags(parser=parser)
    _add_input_flags(parser=parser, placeholders=placeholders)
    _add_output_flags(parser=parser)
def _create_compile_parser(
    parser: argparse.ArgumentParser,
) -> None:
    """Populates `parser` with the arguments of the `compile` command.

    `compile` "compiles" a prompt and model call into a callable function. It
    takes one positional argument, "compile_save_name", plus the model flags.

    Args:
      parser: The parser to which flags will be added.
    """

    def _compile_save_name_fn(var_name: str) -> str:
        """Returns `var_name` if it passes `py_utils.validate_var_name`."""
        try:
            py_utils.validate_var_name(var_name)
        except ValueError as err:
            # Re-raise as ArgumentError to preserve the original error message.
            raise argparse.ArgumentError(None, str(err)) from err
        else:
            return var_name

    # Positional argument for "compile_save_name".
    parser.add_argument(
        "compile_save_name",
        help="The name of a Python variable to save the compiled function to.",
        type=_compile_save_name_fn,
    )
    _add_model_flags(parser)
def _create_compare_parser(
    parser: argparse.ArgumentParser,
    placeholders: AbstractSet[str] | None,
) -> None:
    """Populates `parser` with the arguments of the `compare` command.

    The command takes two positional arguments, "lhs_name_and_fn" and
    "rhs_name_and_fn", followed by the input, output and compare flags.

    Args:
      parser: The parser to which flags will be added.
      placeholders: Placeholders from prompts in the compiled functions.
    """

    def _resolve_llm_function_fn(
        var_name: str,
    ) -> tuple[str, llm_function.LLMFunction]:
        """Returns (var_name, the LLMFunction stored in that variable)."""
        try:
            py_utils.validate_var_name(var_name)
        except ValueError as err:
            # Re-raise as ArgumentError to preserve the original error message.
            raise argparse.ArgumentError(None, str(err)) from err

        candidate = py_utils.get_py_var(var_name)
        if isinstance(candidate, llm_function.LLMFunction):
            return var_name, candidate
        raise argparse.ArgumentError(
            None,
            '{} is not a function created with the "compile" command'.format(var_name),
        )

    name_help = (
        "The name of a Python variable containing a function previously created"
        ' with the "compile" command.'
    )
    # Positional arguments: the left-hand side first, then the right-hand side.
    for positional_name in ("lhs_name_and_fn", "rhs_name_and_fn"):
        parser.add_argument(positional_name, help=name_help, type=_resolve_llm_function_fn)

    _add_input_flags(parser, placeholders)
    _add_output_flags(parser)
    _add_compare_flags(parser)
def _create_eval_parser(
    parser: argparse.ArgumentParser,
    placeholders: AbstractSet[str] | None,
) -> None:
    """Populates `parser` with the flags of the `eval` command.

    The command accepts the model, input, output, compare and eval flags,
    registered in that order.

    Args:
      parser: The parser to which flags will be added.
      placeholders: Placeholders from prompts in the cell contents.
    """
    _add_model_flags(parser)
    _add_input_flags(parser, placeholders)
    # The remaining flag groups only need the parser itself.
    for add_flags in (_add_output_flags, _add_compare_flags, _add_eval_flags):
        add_flags(parser)
def _create_parser(
    placeholders: AbstractSet[str] | None,
) -> argparse.ArgumentParser:
    """Create the full parser.

    The top-level parser gets one sub-parser per command (run, compile,
    compare, eval); the chosen command name is stored under "cmd".

    Args:
      placeholders: Placeholders passed to the commands that read inputs.

    Returns:
      The fully configured parser.
    """
    parser = argument_parser.ArgumentParser(
        prog="llm",
        description="A system for interacting with LLMs.",
        epilog="",
    )
    subparsers = parser.add_subparsers(dest="cmd")
    commands = parsed_args_lib.CommandName

    run_parser = subparsers.add_parser(commands.RUN_CMD.value)
    _create_run_parser(run_parser, placeholders)

    compile_parser = subparsers.add_parser(commands.COMPILE_CMD.value)
    _create_compile_parser(compile_parser)

    compare_parser = subparsers.add_parser(commands.COMPARE_CMD.value)
    _create_compare_parser(compare_parser, placeholders)

    eval_parser = subparsers.add_parser(commands.EVAL_CMD.value)
    _create_eval_parser(eval_parser, placeholders)

    return parser
def _validate_parsed_args(parsed_args: parsed_args_lib.ParsedArgs) -> None:
# If candidate_count is not set (i.e. is None), assuming the default value
# is 1.
if parsed_args.unique and (
parsed_args.model_args.candidate_count is None
or parsed_args.model_args.candidate_count == 1
):
print(
'"--unique" works across candidates only: it should be used with'
" --candidate_count set to a value greater-than one."
)
class CmdLineParser:
    """Implementation of Magics command line parser.

    A line consists of a command (with its flags), optionally followed by
    post-processing expressions that are separated by `PIPE_OP`.
    """

    # Commands
    DEFAULT_CMD = parsed_args_lib.CommandName.RUN_CMD

    # Post-processing operator.
    PIPE_OP = "|"

    @classmethod
    def _split_post_processing_tokens(
        cls,
        tokens: Sequence[str],
    ) -> tuple[Sequence[str], parsed_args_lib.PostProcessingTokens]:
        """Splits inputs into the command and post processing tokens.

        The command is represented as a sequence of tokens.
        See comments on the PostProcessingTokens type alias.

        E.g. Given: "run --temperature 0.5 | add_score | to_lower_case"
        The command will be: ["run", "--temperature", "0.5"].
        The post processing tokens will be: [["add_score"], ["to_lower_case"]]

        A PIPE_OP with nothing before or after it yields an empty segment.

        Args:
          tokens: The command line tokens.

        Returns:
          A tuple of (command line, post processing tokens).
        """
        segments = []
        segment_start = 0
        for token_num, token in enumerate(tokens):
            if token == CmdLineParser.PIPE_OP:
                segments.append(tokens[segment_start:token_num])
                # The next segment begins right after this PIPE_OP.
                segment_start = token_num + 1

        # Add the remaining tokens after the last PIPE_OP (possibly none).
        segments.append(tokens[segment_start:])
        return segments[0], segments[1:]

    @classmethod
    def _tokenize_line(
        cls, line: str
    ) -> tuple[Sequence[str], parsed_args_lib.PostProcessingTokens]:
        """Parses `line` and returns command line and post processing tokens.

        The default command is inserted at the front when the line is empty,
        or when its first token neither starts with a letter nor is a help
        flag ("-h" / "--help").

        Args:
          line: The raw command line.

        Returns:
          A tuple of (command line tokens, post processing tokens).
        """
        # Check to make sure there is a command at the start. If not, add the
        # default command to the list of tokens.
        tokens = shlex.split(line)
        if not tokens:
            tokens = [CmdLineParser.DEFAULT_CMD.value]

        first_token = tokens[0]
        # shlex can produce an empty first token (e.g. from a line starting
        # with ""), so slice instead of indexing to avoid an IndexError; an
        # empty token is then handled like any other non-command token.
        starts_with_letter = first_token[:1].isalpha()
        # Add default command if the first token is not the help token.
        if not starts_with_letter and first_token not in ["-h", "--help"]:
            tokens = [CmdLineParser.DEFAULT_CMD.value] + tokens

        # Split line into tokens and post-processing
        return CmdLineParser._split_post_processing_tokens(tokens)

    @classmethod
    def _get_model_args(
        cls, parsed_results: MutableMapping[str, Any]
    ) -> tuple[MutableMapping[str, Any], model_lib.ModelArguments]:
        """Extracts fields for model args from `parsed_results`.

        Keys specific to model arguments will be removed from `parsed_results`.
        A key that is absent is treated as None.

        Args:
          parsed_results: A dictionary of parsed arguments (from ArgumentParser). It
            will be modified in place.

        Returns:
          A tuple of (updated parsed_results, model arguments).
        """
        model = parsed_results.pop("model", None)
        temperature = parsed_results.pop("temperature", None)
        candidate_count = parsed_results.pop("candidate_count", None)

        model_args = model_lib.ModelArguments(
            model=model,
            temperature=temperature,
            candidate_count=candidate_count,
        )
        return parsed_results, model_args

    def parse_line(
        self,
        line: str,
        placeholders: AbstractSet[str] | None = None,
    ) -> tuple[parsed_args_lib.ParsedArgs, parsed_args_lib.PostProcessingTokens]:
        """Parses the commandline and returns ParsedArgs and post-processing tokens.

        Args:
          line: The line to parse (usually contents from cell Magics).
          placeholders: Placeholders from prompts in the cell contents.

        Returns:
          A tuple of (parsed_args, post_processing_tokens).
        """
        tokens, post_processing_tokens = CmdLineParser._tokenize_line(line)

        parsed_args = self._get_parsed_args_from_cmd_line_tokens(
            tokens=tokens, placeholders=placeholders
        )

        # Special-case for "compare" command: because the prompts are compiled into
        # the left- and right-hand side functions rather than in the cell body, we
        # cannot examine the cell body to get the placeholders.
        #
        # Instead we parse the command line twice: once to get the left- and right-
        # functions, then we query the functions for their placeholders, then
        # parse the commandline again to validate the inputs.
        if parsed_args.cmd == parsed_args_lib.CommandName.COMPARE_CMD:
            assert parsed_args.lhs_name_and_fn is not None
            assert parsed_args.rhs_name_and_fn is not None
            _, lhs_fn = parsed_args.lhs_name_and_fn
            _, rhs_fn = parsed_args.rhs_name_and_fn
            parsed_args = self._get_parsed_args_from_cmd_line_tokens(
                tokens=tokens,
                placeholders=frozenset(lhs_fn.get_placeholders()).union(rhs_fn.get_placeholders()),
            )

        _validate_parsed_args(parsed_args)

        for expr in post_processing_tokens:
            post_process_utils.validate_one_post_processing_expression(expr)

        return parsed_args, post_processing_tokens

    def _get_parsed_args_from_cmd_line_tokens(
        self,
        tokens: Sequence[str],
        placeholders: AbstractSet[str] | None,
    ) -> parsed_args_lib.ParsedArgs:
        """Returns ParsedArgs from a tokenized command line.

        Args:
          tokens: The command line tokens, without post-processing tokens.
          placeholders: Placeholders used to validate input sources.
        """
        # Create a new parser to avoid reusing the temporary argparse.Namespace
        # object.
        results = _create_parser(placeholders).parse_args(tokens)

        results_dict = vars(results)
        # Replace the command string with its CommandName enum member.
        results_dict["cmd"] = parsed_args_lib.CommandName(results_dict["cmd"])

        # Move the model-specific entries into a single "model_args" entry.
        results_dict, model_args = CmdLineParser._get_model_args(results_dict)
        results_dict["model_args"] = model_args

        return parsed_args_lib.ParsedArgs(**results_dict)