google/generativeai/notebook/flag_def.py (283 lines of code) (raw):

# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Classes that define arguments for populating ArgumentParser. The argparse module's ArgumentParser.add_argument() takes several parameters and is quite customizable. However this can lead to bugs where arguments do not behave as expected. For better ease-of-use and better testability, define a set of classes for the types of flags used by LLM Magics. Sample usage: str_flag = SingleValueFlagDef(name="title", required=True) enum_flag = EnumFlagDef(name="colors", required=True, enum_type=ColorsEnum) str_flag.add_argument_to_parser(my_parser) enum_flag.add_argument_to_parser(my_parser) """ from __future__ import annotations import abc import argparse import dataclasses import enum from typing import Any, Callable, Sequence, Tuple, Union from google.generativeai.notebook.lib import llmfn_inputs_source from google.generativeai.notebook.lib import llmfn_outputs # These are the intermediate types that argparse.ArgumentParser.parse_args() # will pass command line arguments into. _PARSETYPES = Union[str, int, float] # These are the final result types that the intermediate parsed values will be # converted into. It is a superset of _PARSETYPES because we support converting # the parsed type into a more precise type, e.g. from str to Enum. 
_DESTTYPES = Union[
    _PARSETYPES,
    enum.Enum,
    Tuple[str, Callable[[str, str], Any]],  # For --compare_fn
    Sequence[str],  # For --ground_truth
    llmfn_inputs_source.LLMFnInputsSource,  # For --inputs
    llmfn_outputs.LLMFnOutputsSink,  # For --outputs
]

# The signature of a function that converts a command line argument from the
# intermediate parsed type to the result type.
_PARSEFN = Callable[[_PARSETYPES], _DESTTYPES]


def _get_type_name(x: type[Any]) -> str:
    """Returns a readable name for `x`.

    Uses `x.__name__` when available and falls back to `str(x)` for objects
    that have no `__name__` attribute.
    """
    try:
        return x.__name__
    except AttributeError:
        return str(x)


def _validate_flag_name(name: str) -> str:
    """Validation for long and short names for flags.

    Args:
      name: The flag name as written without leading dashes, e.g. "colors".

    Returns:
      `name`, unchanged, if it is valid.

    Raises:
      ValueError: If `name` is empty or starts with a dash.
    """
    if not name:
        raise ValueError("Cannot be empty")
    if name[0] == "-":
        raise ValueError("Cannot start with dash")
    return name


@dataclasses.dataclass(frozen=True)
class FlagDef(abc.ABC):
    """Abstract base class for flag definitions.

    Attributes:
      name: Long name, e.g. "colors" will define the flag "--colors".
      required: Whether the flag must be provided on the command line.
      short_name: Optional short name, e.g. "c" will define the flag "-c".
      parse_type: The type that ArgumentParser should parse the command line
        argument to.
      dest_type: The type that the parsed value is converted to. This is used
        when we want ArgumentParser to parse as one type, then convert to a
        different type. E.g. for enums we parse as "str" then convert to the
        desired enum type in order to provide cleaner help messages.
      parse_to_dest_type_fn: If provided, this function will be used to convert
        the value from `parse_type` to `dest_type`. This can be used for
        validation as well.
      choices: If provided, limit the set of acceptable values to these choices.
        argparse compares them against the value after `parse_type` conversion
        but before `parse_to_dest_type_fn` is applied.
      help_msg: If provided, adds help message when -h is used in the command
        line.
    """

    name: str
    required: bool = False
    short_name: str | None = None

    parse_type: type[_PARSETYPES] = str
    dest_type: type[_DESTTYPES] | None = None
    parse_to_dest_type_fn: _PARSEFN | None = None

    choices: list[_PARSETYPES] | None = None
    help_msg: str | None = None

    @abc.abstractmethod
    def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None:
        """Adds this flag as an argument to `parser`.

        Child classes should implement this as a call to parser.add_argument()
        with the appropriate parameters.

        Args:
          parser: The parser to which this argument will be added.
        """

    @abc.abstractmethod
    def _do_additional_validation(self) -> None:
        """For child classes to do additional validation."""

    def _get_dest_type(self) -> type[_DESTTYPES]:
        """Returns the final converted type (`parse_type` if no `dest_type`)."""
        return self.parse_type if self.dest_type is None else self.dest_type

    def _get_parse_to_dest_type_fn(
        self,
    ) -> _PARSEFN:
        """Returns a function to convert from parse_type to dest_type."""
        if self.parse_to_dest_type_fn is not None:
            return self.parse_to_dest_type_fn

        dest_type = self._get_dest_type()
        if dest_type == self.parse_type:
            # Nothing to convert: pass the parsed value through unchanged.
            return lambda x: x
        # Use the destination type itself as the converter, e.g. calling an
        # Enum class with a parsed string looks up the member with that value.
        return dest_type

    def __post_init__(self):
        _validate_flag_name(self.name)
        if self.short_name is not None:
            _validate_flag_name(self.short_name)
        self._do_additional_validation()


def _has_non_default_value(
    namespace: argparse.Namespace,
    dest: str,
    has_default: bool = False,
    default_value: Any = None,
) -> bool:
    """Returns true if `namespace.dest` is set to a non-default value.

    Args:
      namespace: The Namespace that is populated by ArgumentParser.
      dest: The attribute in the Namespace to be populated.
      has_default: "None" is a valid default value so we use an additional
        `has_default` boolean to indicate that `default_value` is present.
      default_value: The default value to use when `has_default` is True.

    Returns:
      Whether namespace.dest is set to something other than the default value.
    """
    if not hasattr(namespace, dest):
        return False

    if not has_default:
        # No default value provided so `namespace.dest` cannot possibly be equal
        # to the default value.
        return True

    return getattr(namespace, dest) != default_value


class _SingleValueStoreAction(argparse.Action):
    """Custom Action for storing a value in an argparse.Namespace.

    The parsed value is converted with `parse_to_dest_type_fn` and must be an
    instance of `dest_type`.

    This action checks that the flag is specified at-most once.
    """

    def __init__(
        self,
        option_strings,
        dest,
        dest_type: type[Any],
        parse_to_dest_type_fn: _PARSEFN,
        **kwargs,
    ):
        super().__init__(option_strings, dest, **kwargs)
        self._dest_type = dest_type
        self._parse_to_dest_type_fn = parse_to_dest_type_fn

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: str | Sequence[Any] | None,
        option_string: str | None = None,
    ):
        # Because `nargs` is set to 1, `values` must be a Sequence, rather
        # than a string.
        assert not isinstance(values, str) and not isinstance(values, bytes)

        # NOTE(review): this "at-most once" check compares against the default,
        # so repeating a flag whose first value equals the default is not
        # detected.
        if _has_non_default_value(
            namespace,
            self.dest,
            has_default=hasattr(self, "default"),
            default_value=getattr(self, "default", None),
        ):
            raise argparse.ArgumentError(self, "Cannot set {} more than once".format(option_string))

        try:
            converted_value = self._parse_to_dest_type_fn(values[0])
        except Exception as e:
            # Surface any conversion/validation failure as a usage error.
            raise argparse.ArgumentError(
                self,
                'Error with value "{}", got {}: {}'.format(values[0], _get_type_name(type(e)), e),
            ) from e

        if not isinstance(converted_value, self._dest_type):
            # A mismatch here is a bug in the flag definition, not user error.
            raise RuntimeError(
                "Converted to wrong type, expected {} got {}".format(
                    _get_type_name(self._dest_type),
                    _get_type_name(type(converted_value)),
                )
            )

        setattr(namespace, self.dest, converted_value)


class _MultiValuesAppendAction(argparse.Action):
    """Custom Action for appending values in an argparse.Namespace.

    Each parsed value is converted with `parse_to_dest_type_fn` and must be an
    instance of `dest_type`; duplicate converted values are rejected. The
    converted values are stored on the namespace as a list.

    This action checks that the flag is specified at-most once.
    """

    def __init__(
        self,
        option_strings,
        dest,
        dest_type: type[Any],
        parse_to_dest_type_fn: _PARSEFN,
        **kwargs,
    ):
        super().__init__(option_strings, dest, **kwargs)
        self._dest_type = dest_type
        self._parse_to_dest_type_fn = parse_to_dest_type_fn

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: str | Sequence[Any] | None,
        option_string: str | None = None,
    ):
        # Because `nargs` is set to "+", `values` must be a Sequence, rather
        # than a string.
        assert not isinstance(values, str) and not isinstance(values, bytes)

        curr_value = getattr(namespace, self.dest, None)
        if curr_value:
            raise argparse.ArgumentError(self, "Cannot set {} more than once".format(option_string))

        # Build a fresh list instead of appending to `curr_value`: argparse puts
        # the action's `default` object itself on the namespace, so mutating it
        # in place would leak values into later parse_args() calls on the same
        # parser.
        new_value: list[Any] = []
        for value in values:
            try:
                converted_value = self._parse_to_dest_type_fn(value)
            except Exception as e:
                # Report the value that actually failed, not the first one.
                raise argparse.ArgumentError(
                    self,
                    'Error with value "{}", got {}: {}'.format(value, _get_type_name(type(e)), e),
                ) from e

            if not isinstance(converted_value, self._dest_type):
                # A mismatch here is a bug in the flag definition, not user error.
                raise RuntimeError(
                    "Converted to wrong type, expected {} got {}".format(
                        _get_type_name(self._dest_type),
                        _get_type_name(type(converted_value)),
                    )
                )

            if converted_value in new_value:
                raise argparse.ArgumentError(self, 'Duplicate values "{}"'.format(value))

            new_value.append(converted_value)

        setattr(namespace, self.dest, new_value)


class _BooleanValueStoreAction(argparse.Action):
    """Custom Action for setting a boolean value in argparse.Namespace.

    The boolean flag expects the default to be False and will set the value to
    True.

    This action checks that the flag is specified at-most once.
    """

    def __init__(
        self,
        option_strings,
        dest,
        **kwargs,
    ):
        super().__init__(option_strings, dest, **kwargs)

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: str | Sequence[Any] | None,
        option_string: str | None = None,
    ):
        # Anything other than the False default means the flag was already seen.
        if _has_non_default_value(
            namespace,
            self.dest,
            has_default=True,
            default_value=False,
        ):
            raise argparse.ArgumentError(self, "Cannot set {} more than once".format(option_string))

        setattr(namespace, self.dest, True)


@dataclasses.dataclass(frozen=True)
class SingleValueFlagDef(FlagDef):
    """Definition for a flag that takes a single value.

    Sample usage:
      # This defines a flag that can be specified on the command line as:
      #   --count=10
      flag = SingleValueFlagDef(name="count", parse_type=int, required=True)
      flag.add_argument_to_parser(argument_parser)

    Attributes:
      default_value: Default value for optional flags. Required flags must not
        set it; optional flags must set it (None is allowed).
    """

    class _DefaultValue(enum.Enum):
        """Special value to represent "no value provided".

        "None" can be used as a default value, so in order to differentiate
        between "None" and "no value provided", create a special value for "no
        value provided".
        """

        NOT_SET = None

    default_value: _DESTTYPES | _DefaultValue | None = _DefaultValue.NOT_SET

    def _has_default_value(self) -> bool:
        """Returns whether `default_value` has been provided."""
        return self.default_value is not SingleValueFlagDef._DefaultValue.NOT_SET

    def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None:
        args = ["--" + self.name]
        if self.short_name is not None:
            args += ["-" + self.short_name]

        # Only forward optional settings that were actually provided.
        kwargs = {}
        if self._has_default_value():
            kwargs["default"] = self.default_value
        if self.choices is not None:
            kwargs["choices"] = self.choices
        if self.help_msg is not None:
            kwargs["help"] = self.help_msg

        parser.add_argument(
            *args,
            action=_SingleValueStoreAction,
            type=self.parse_type,
            dest_type=self._get_dest_type(),
            parse_to_dest_type_fn=self._get_parse_to_dest_type_fn(),
            required=self.required,
            nargs=1,
            **kwargs,
        )

    def _do_additional_validation(self) -> None:
        if self.required:
            if self._has_default_value():
                raise ValueError("Required flags cannot have default value")
        else:
            if not self._has_default_value():
                raise ValueError("Optional flags must have a default value")

        # None is accepted as a default regardless of the destination type.
        if self._has_default_value() and self.default_value is not None:
            if not isinstance(self.default_value, self._get_dest_type()):
                raise ValueError("Default value must be of the same type as the destination type")


class EnumFlagDef(SingleValueFlagDef):
    """Definition for a flag that takes a value from an Enum.

    The command line value is parsed as a string and converted to a member of
    `enum_type` by value. Unless `choices` is given, the accepted values are
    all of the enum's values.

    Sample usage:
      # This defines a flag that can be specified on the command line as:
      #   --color=red
      flag = EnumFlagDef(name="color", enum_type=ColorsEnum, required=True)
      flag.add_argument_to_parser(argument_parser)
    """

    def __init__(self, *args, enum_type: type[enum.Enum], **kwargs):
        if not issubclass(enum_type, enum.Enum):
            raise TypeError('"enum_type" must be of type Enum')

        # These properties are set by "enum_type" so don't let the caller set
        # them.
        if "parse_type" in kwargs:
            raise ValueError('Cannot set "parse_type" for EnumFlagDef; set "enum_type" instead')
        kwargs["parse_type"] = str

        if "dest_type" in kwargs:
            raise ValueError('Cannot set "dest_type" for EnumFlagDef; set "enum_type" instead')
        kwargs["dest_type"] = enum_type

        choices = kwargs.get("choices")
        if choices is None:
            # Default to every value of the enum (also when choices=None is
            # passed explicitly, which previously crashed iterating None).
            kwargs["choices"] = [x.value for x in enum_type]
        else:
            # Verify that entries in `choices` are valid enum values.
            for x in choices:
                try:
                    enum_type(x)
                except ValueError:
                    raise ValueError('Invalid value in "choices": "{}"'.format(x)) from None

        super().__init__(*args, **kwargs)


class MultiValuesFlagDef(FlagDef):
    """Definition for a flag that takes multiple values.

    The parsed result is a list of converted values; it defaults to an empty
    list when the flag is not given.

    Sample usage:
      # This defines a flag that can be specified on the command line as:
      #   --colors red green blue
      flag = MultiValuesFlagDef(name="colors", parse_type=str, required=True)
      flag.add_argument_to_parser(argument_parser)
    """

    def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None:
        args = ["--" + self.name]
        if self.short_name is not None:
            args += ["-" + self.short_name]

        # Only forward optional settings that were actually provided.
        kwargs = {}
        if self.choices is not None:
            kwargs["choices"] = self.choices
        if self.help_msg is not None:
            kwargs["help"] = self.help_msg

        parser.add_argument(
            *args,
            action=_MultiValuesAppendAction,
            type=self.parse_type,
            dest_type=self._get_dest_type(),
            parse_to_dest_type_fn=self._get_parse_to_dest_type_fn(),
            required=self.required,
            default=[],
            nargs="+",
            **kwargs,
        )

    def _do_additional_validation(self) -> None:
        # No additional validation needed.
        pass


@dataclasses.dataclass(frozen=True)
class BooleanFlagDef(FlagDef):
    """Definition for a Boolean flag.

    A boolean flag is always optional with a default value of False. The flag
    does not take any values. Specifying the flag on the commandline will set it
    to True.
    """

    def _do_additional_validation(self) -> None:
        if self.dest_type is not None:
            raise ValueError("dest_type cannot be set for BooleanFlagDef")
        if self.parse_to_dest_type_fn is not None:
            raise ValueError("parse_to_dest_type_fn cannot be set for BooleanFlagDef")
        if self.choices is not None:
            raise ValueError("choices cannot be set for BooleanFlagDef")

    def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None:
        args = ["--" + self.name]
        if self.short_name is not None:
            args += ["-" + self.short_name]

        kwargs = {}
        if self.help_msg is not None:
            kwargs["help"] = self.help_msg

        # nargs=0: the flag consumes no values; its presence alone sets True.
        parser.add_argument(
            *args,
            action=_BooleanValueStoreAction,
            type=bool,
            required=False,
            default=False,
            nargs=0,
            **kwargs,
        )