google/generativeai/notebook/lib/llm_function.py (309 lines of code) (raw):

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LLMFunction."""
from __future__ import annotations

import abc
import dataclasses
from typing import (
    AbstractSet,
    Any,
    Callable,
    Iterable,
    Mapping,
    Optional,
    Sequence,
    Union,
)

from google.generativeai.notebook.lib import llmfn_input_utils
from google.generativeai.notebook.lib import llmfn_output_row
from google.generativeai.notebook.lib import llmfn_outputs
from google.generativeai.notebook.lib import llmfn_post_process
from google.generativeai.notebook.lib import llmfn_post_process_cmds
from google.generativeai.notebook.lib import model as model_lib
from google.generativeai.notebook.lib import prompt_utils


# In the same spirit as post-processing functions (see: llmfn_post_process.py),
# we keep the LLM functions more flexible by providing the entire left- and
# right-hand side rows to the user-defined comparison function.
#
# Possible use-cases include adding a scoring function as a post-process
# command, then comparing the scores.
CompareFn = Callable[
    [llmfn_output_row.LLMFnOutputRowView, llmfn_output_row.LLMFnOutputRowView],
    Any,
]


def _is_equal_fn(
    lhs: llmfn_output_row.LLMFnOutputRowView,
    rhs: llmfn_output_row.LLMFnOutputRowView,
) -> bool:
    """Default function used when comparing outputs.

    Args:
      lhs: Row produced by the left-hand side function.
      rhs: Row produced by the right-hand side function.

    Returns:
      Whether the result values of the two rows compare equal.
    """
    return lhs.result_value() == rhs.result_value()


def _convert_compare_fn_to_batch_add_fn(
    fn: CompareFn,
) -> llmfn_post_process.LLMCompareFnPostProcessBatchAddFn:
    """Vectorize a single-row-based comparison function.

    Args:
      fn: Comparison function taking one left-hand side row and one right-hand
        side row.

    Returns:
      A function that applies `fn` to every (lhs, rhs) pair in a sequence and
      returns the results as a list, in the same order.
    """

    def _fn(
        lhs_and_rhs_rows: Sequence[
            tuple[
                llmfn_output_row.LLMFnOutputRowView,
                llmfn_output_row.LLMFnOutputRowView,
            ]
        ],
    ) -> Sequence[Any]:
        return [fn(lhs, rhs) for lhs, rhs in lhs_and_rhs_rows]

    return _fn


@dataclasses.dataclass
class _PromptInfo:
    """Fields describing one (prompt, input) combination to be executed."""

    # Index of the prompt template within the sequence of prompts.
    prompt_num: int
    # The prompt template, before keyword substitution.
    prompt: str
    # Index of the input within the normalized inputs.
    input_num: int
    # Key/value pairs substituted into the prompt template.
    prompt_vars: Mapping[str, str]
    # The prompt after keyword substitution, i.e. what is sent to the model.
    model_input: str


def _generate_prompts(
    prompts: Sequence[str], inputs: llmfn_input_utils.LLMFunctionInputs | None
) -> Iterable[_PromptInfo]:
    """Generate a tuple of fields needed for processing prompts.

    Prompts are iterated in the outer loop and inputs in the inner loop, so
    all inputs for the first prompt are yielded before the second prompt.

    Args:
      prompts: A list of prompts, with optional keyword placeholders.
      inputs: A list of key/value pairs to substitute into placeholders in
        `prompts`.

    Yields:
      A _PromptInfo instance.
    """
    normalized_inputs: Sequence[Mapping[str, str]] = []
    if inputs is not None:
        normalized_inputs = llmfn_input_utils.to_normalized_inputs(inputs)

    # Must have at least one entry so that we execute the prompt at least once.
    if not normalized_inputs:
        normalized_inputs = [{}]

    for prompt_num, prompt in enumerate(prompts):
        for input_num, prompt_vars in enumerate(normalized_inputs):
            # Perform keyword substitution on the prompt based on `prompt_vars`.
            model_input = prompt.format(**prompt_vars)
            yield _PromptInfo(
                prompt_num=prompt_num,
                prompt=prompt,
                input_num=input_num,
                prompt_vars=prompt_vars,
                model_input=model_input,
            )


class LLMFunction(
    Callable[
        [Union[llmfn_input_utils.LLMFunctionInputs, None]],
        llmfn_outputs.LLMFnOutputs,
    ],
    metaclass=abc.ABCMeta,
):
    """Base class for LLMFunctionImpl and LLMCompareFunction."""

    def __init__(
        self,
        outputs_ipython_display_fn: Callable[[llmfn_outputs.LLMFnOutputs], None] | None = None,
    ):
        """Constructor.

        Args:
          outputs_ipython_display_fn: Optional function that will be used to
            override how the outputs of this LLMFunction will be displayed in a
            notebook (See further documentation in LLMFnOutputs.__init__().)
        """
        self._post_process_cmds: list[llmfn_post_process_cmds.LLMFnPostProcessCommand] = []
        self._outputs_ipython_display_fn = outputs_ipython_display_fn

    @abc.abstractmethod
    def get_placeholders(self) -> AbstractSet[str]:
        """Returns the placeholders that should be present in inputs for this function."""

    @abc.abstractmethod
    def _call_impl(
        self, inputs: llmfn_input_utils.LLMFunctionInputs | None
    ) -> Sequence[llmfn_outputs.LLMFnOutputEntry]:
        """Concrete implementation of __call__()."""

    def __call__(
        self, inputs: llmfn_input_utils.LLMFunctionInputs | None = None
    ) -> llmfn_outputs.LLMFnOutputs:
        """Runs and returns results based on `inputs`."""
        outputs = self._call_impl(inputs)
        return llmfn_outputs.LLMFnOutputs(
            outputs=outputs, ipython_display_fn=self._outputs_ipython_display_fn
        )

    def add_post_process_reorder_fn(
        self, name: str, fn: llmfn_post_process.LLMFnPostProcessBatchReorderFn
    ) -> LLMFunction:
        """Appends a reorder post-processing command and returns `self`.

        Args:
          name: Name given to the command.
          fn: Batch function wrapped by the command.

        Returns:
          This instance, so that calls can be chained.
        """
        self._post_process_cmds.append(
            llmfn_post_process_cmds.LLMFnPostProcessReorderCommand(name=name, fn=fn)
        )
        return self

    def add_post_process_add_fn(
        self,
        name: str,
        fn: llmfn_post_process.LLMFnPostProcessBatchAddFn,
    ) -> LLMFunction:
        """Appends an add post-processing command and returns `self`.

        Args:
          name: Name given to the command.
          fn: Batch function wrapped by the command.

        Returns:
          This instance, so that calls can be chained.
        """
        self._post_process_cmds.append(
            llmfn_post_process_cmds.LLMFnPostProcessAddCommand(name=name, fn=fn)
        )
        return self

    def add_post_process_replace_fn(
        self,
        name: str,
        fn: llmfn_post_process.LLMFnPostProcessBatchReplaceFn,
    ) -> LLMFunction:
        """Appends a replace post-processing command and returns `self`.

        Args:
          name: Name given to the command.
          fn: Batch function wrapped by the command.

        Returns:
          This instance, so that calls can be chained.
        """
        self._post_process_cmds.append(
            llmfn_post_process_cmds.LLMFnPostProcessReplaceCommand(name=name, fn=fn)
        )
        return self


class LLMFunctionImpl(LLMFunction):
    """Callable class that executes the contents of a Magics cell.

    An LLMFunction is constructed from the Magics command line and cell contents
    specified by the user. It is defined by:
    - A model instance,
    - Model arguments
    - A prompt template (e.g. "The opposite of {word} is") with an optional
      keyword placeholder.

    The LLMFunction takes as its input a sequence of dictionaries containing
    values for keyword replacement, e.g. [{"word": "hot"}, {"word": "tall"}].

    This will cause the model to be executed with the following prompts:
      "The opposite of hot is"
      "The opposite of tall is"

    The results will be returned in a LLMFnOutputs instance.
    """

    def __init__(
        self,
        model: model_lib.AbstractModel,
        prompts: Sequence[str],
        model_args: model_lib.ModelArguments | None = None,
        outputs_ipython_display_fn: Callable[[llmfn_outputs.LLMFnOutputs], None] | None = None,
    ):
        """Constructor.

        Args:
          model: The model that the prompts will execute on.
          prompts: A sequence of prompt templates with optional placeholders. The
            placeholders will be replaced by the inputs passed into this
            function.
          model_args: Optional set of model arguments to configure how the model
            executes the prompts.
          outputs_ipython_display_fn: See documentation in
            LLMFunction.__init__().
        """
        super().__init__(outputs_ipython_display_fn=outputs_ipython_display_fn)
        self._model = model
        self._prompts = prompts
        self._model_args = model_lib.ModelArguments() if model_args is None else model_args

        # Compute placeholders: the union of the placeholders of every prompt.
        self._placeholders = frozenset().union(
            *(prompt_utils.get_placeholders(prompt) for prompt in self._prompts)
        )

    def _run_post_processing_cmds(
        self, results: Sequence[llmfn_output_row.LLMFnOutputRow]
    ) -> Sequence[llmfn_output_row.LLMFnOutputRow]:
        """Runs post-processing commands over `results`.

        Commands run in the order they were added; each command receives the
        output of the previous one.

        Args:
          results: The rows to post-process.

        Returns:
          The rows produced by the last command (or `results` unchanged if there
          are no commands).

        Raises:
          PostProcessExecutionError: If a command is of an unsupported type, or
            if running a command raises a RuntimeError.
        """
        for cmd in self._post_process_cmds:
            try:
                if isinstance(cmd, llmfn_post_process_cmds.LLMFnImplPostProcessCommand):
                    results = cmd.run(results)
                else:
                    raise llmfn_post_process.PostProcessExecutionError(
                        "Unsupported post-process command type: {}".format(type(cmd))
                    )
            except llmfn_post_process.PostProcessExecutionError:
                # Already the right exception type; propagate without wrapping.
                raise
            except RuntimeError as e:
                # Chain explicitly so the original error is kept as the cause.
                raise llmfn_post_process.PostProcessExecutionError(
                    'Error executing "{}", got {}: {}'.format(cmd.name(), type(e).__name__, e)
                ) from e
        return results

    def get_placeholders(self) -> AbstractSet[str]:
        """Returns the union of the placeholders found in all prompts."""
        return self._placeholders

    def _call_impl(
        self, inputs: llmfn_input_utils.LLMFunctionInputs | None
    ) -> Sequence[llmfn_outputs.LLMFnOutputEntry]:
        """Calls the model once per (prompt, input) pair and collects the outputs.

        Args:
          inputs: Values for keyword substitution into the prompts, or None.

        Returns:
          One LLMFnOutputEntry per (prompt, input) pair, in generation order.
        """
        results: list[llmfn_outputs.LLMFnOutputEntry] = []
        for info in _generate_prompts(prompts=self._prompts, inputs=inputs):
            model_results = self._model.call_model(
                model_input=info.model_input, model_args=self._model_args
            )
            # One row per text result returned by the model for this input.
            output_rows: list[llmfn_output_row.LLMFnOutputRow] = [
                llmfn_output_row.LLMFnOutputRow(
                    data={
                        llmfn_outputs.ColumnNames.RESULT_NUM: result_num,
                        llmfn_outputs.ColumnNames.TEXT_RESULT: text_result,
                    },
                    result_type=str,
                )
                for result_num, text_result in enumerate(model_results.text_results)
            ]
            results.append(
                llmfn_outputs.LLMFnOutputEntry(
                    prompt_num=info.prompt_num,
                    input_num=info.input_num,
                    prompt=info.prompt,
                    prompt_vars=info.prompt_vars,
                    model_input=info.model_input,
                    model_results=model_results,
                    output_rows=self._run_post_processing_cmds(output_rows),
                )
            )
        return results


class LLMCompareFunction(LLMFunction):
    """LLMFunction for comparisons.

    LLMCompareFunction runs an input over a pair of LLMFunctions and compares the
    result.
    """

    def __init__(
        self,
        lhs_name_and_fn: tuple[str, LLMFunction],
        rhs_name_and_fn: tuple[str, LLMFunction],
        compare_name_and_fns: Sequence[tuple[str, CompareFn]] | None = None,
        outputs_ipython_display_fn: Callable[[llmfn_outputs.LLMFnOutputs], None] | None = None,
    ):
        """Constructor.

        Args:
          lhs_name_and_fn: Name and function for the left-hand side of the
            comparison.
          rhs_name_and_fn: Name and function for the right-hand side of the
            comparison.
          compare_name_and_fns: Optional names and functions for comparing the
            results of the left- and right-hand sides. When omitted or empty,
            the result values are compared for equality under the name
            "is_equal".
          outputs_ipython_display_fn: See documentation in
            LLMFunction.__init__().
        """
        super().__init__(outputs_ipython_display_fn=outputs_ipython_display_fn)
        self._lhs_name: str = lhs_name_and_fn[0]
        self._lhs_fn: LLMFunction = lhs_name_and_fn[1]
        self._rhs_name: str = rhs_name_and_fn[0]
        self._rhs_fn: LLMFunction = rhs_name_and_fn[1]

        self._placeholders = frozenset(self._lhs_fn.get_placeholders()).union(
            self._rhs_fn.get_placeholders()
        )

        if not compare_name_and_fns:
            self._result_name = "is_equal"
            self._result_compare_fn = _is_equal_fn
        else:
            # Assume the last entry in `compare_name_and_fns` is the one that
            # produces value for the result cell.
            name, fn = compare_name_and_fns[-1]
            self._result_name = name
            self._result_compare_fn = fn

            # Treat the other compare_fns as post-processing operators.
            for name, cmp_fn in compare_name_and_fns[:-1]:
                self.add_compare_post_process_add_fn(
                    name=name, fn=_convert_compare_fn_to_batch_add_fn(cmp_fn)
                )

    def _run_post_processing_cmds(
        self,
        lhs_output_rows: Sequence[llmfn_output_row.LLMFnOutputRow],
        rhs_output_rows: Sequence[llmfn_output_row.LLMFnOutputRow],
        results: Sequence[llmfn_output_row.LLMFnOutputRow],
    ) -> Sequence[llmfn_output_row.LLMFnOutputRow]:
        """Runs post-processing commands over `results`.

        Args:
          lhs_output_rows: Rows from the left-hand side function.
          rhs_output_rows: Rows from the right-hand side function.
          results: The combined rows to post-process.

        Returns:
          The rows produced by the last command (or `results` unchanged if there
          are no commands).

        Raises:
          PostProcessExecutionError: If a command is of an unsupported type, or
            if running a command raises a RuntimeError.
        """
        for cmd in self._post_process_cmds:
            try:
                if isinstance(cmd, llmfn_post_process_cmds.LLMFnImplPostProcessCommand):
                    results = cmd.run(results)
                elif isinstance(cmd, llmfn_post_process_cmds.LLMCompareFnPostProcessCommand):
                    # Compare commands receive (lhs, rhs, combined) row triples.
                    results = cmd.run(list(zip(lhs_output_rows, rhs_output_rows, results)))
                else:
                    # Raised inside the `try` on purpose: it is wrapped by the
                    # RuntimeError handler below.
                    raise RuntimeError(
                        "Unsupported post-process command type: {}".format(type(cmd))
                    )
            except llmfn_post_process.PostProcessExecutionError:
                # Already the right exception type; propagate without wrapping.
                raise
            except RuntimeError as e:
                # Chain explicitly so the original error is kept as the cause.
                raise llmfn_post_process.PostProcessExecutionError(
                    'Error executing "{}", got {}: {}'.format(cmd.name(), type(e).__name__, e)
                ) from e
        return results

    def get_placeholders(self) -> AbstractSet[str]:
        """Returns the union of the placeholders of both compared functions."""
        return self._placeholders

    def _call_impl(
        self, inputs: llmfn_input_utils.LLMFunctionInputs | None
    ) -> Sequence[llmfn_outputs.LLMFnOutputEntry]:
        """Runs both functions over `inputs` and combines their outputs.

        Args:
          inputs: Values for keyword substitution into the prompts, or None.

        Returns:
          One LLMFnOutputEntry per pair of (lhs, rhs) entries.

        Raises:
          RuntimeError: If paired entries disagree on prompt number, input
            number or prompt variables.
        """
        lhs_results = self._lhs_fn(inputs)
        rhs_results = self._rhs_fn(inputs)

        # Combine the results.
        # NOTE: zip() stops at the shorter of the two, so surplus entries on
        # either side are not compared.
        outputs: list[llmfn_outputs.LLMFnOutputEntry] = []
        for lhs_entry, rhs_entry in zip(lhs_results, rhs_results):
            if lhs_entry.prompt_num != rhs_entry.prompt_num:
                raise RuntimeError(
                    "Prompt num mismatch: {} vs {}".format(
                        lhs_entry.prompt_num, rhs_entry.prompt_num
                    )
                )
            if lhs_entry.input_num != rhs_entry.input_num:
                raise RuntimeError(
                    "Input num mismatch: {} vs {}".format(lhs_entry.input_num, rhs_entry.input_num)
                )
            if lhs_entry.prompt_vars != rhs_entry.prompt_vars:
                raise RuntimeError(
                    "Prompt vars mismatch: {} vs {}".format(
                        lhs_entry.prompt_vars, rhs_entry.prompt_vars
                    )
                )

            # The two functions may have different numbers of results due to
            # options like candidate_count, so we can only compare up to the
            # minimum of the two.
            num_output_rows = min(len(lhs_entry.output_rows), len(rhs_entry.output_rows))
            lhs_output_rows = lhs_entry.output_rows[:num_output_rows]
            rhs_output_rows = rhs_entry.output_rows[:num_output_rows]
            output_rows: list[llmfn_output_row.LLMFnOutputRow] = []
            for result_num, (lhs_output_row, rhs_output_row) in enumerate(
                zip(lhs_output_rows, rhs_output_rows)
            ):
                # Combine cells from lhs_output_row and rhs_output_row into a
                # single row.
                # Although it is possible for RESULT_NUM (the index of each
                # text_result if a prompt produces multiple text_results) to be
                # different between the left and right sides, we ignore their
                # RESULT_NUM entries and write our own.
                row_data: dict[str, Any] = {
                    llmfn_outputs.ColumnNames.RESULT_NUM: result_num,
                    self._result_name: self._result_compare_fn(lhs_output_row, rhs_output_row),
                }
                output_row = llmfn_output_row.LLMFnOutputRow(data=row_data, result_type=Any)

                # Add the prompt vars.
                output_row.add(llmfn_outputs.ColumnNames.PROMPT_VARS, lhs_entry.prompt_vars)

                # Add the results from the left-hand side and right-hand side.
                for name, row in [
                    (self._lhs_name, lhs_output_row),
                    (self._rhs_name, rhs_output_row),
                ]:
                    for k, v in row.items():
                        if k != llmfn_outputs.ColumnNames.RESULT_NUM:
                            # We use LLMFnOutputRow.add() because it handles column
                            # name collisions.
                            output_row.add("{}_{}".format(name, k), v)

                output_rows.append(output_row)

            outputs.append(
                llmfn_outputs.LLMFnOutputEntry(
                    prompt_num=lhs_entry.prompt_num,
                    input_num=lhs_entry.input_num,
                    prompt_vars=lhs_entry.prompt_vars,
                    output_rows=self._run_post_processing_cmds(
                        lhs_output_rows=lhs_output_rows,
                        rhs_output_rows=rhs_output_rows,
                        results=output_rows,
                    ),
                )
            )
        return outputs

    def add_compare_post_process_add_fn(
        self,
        name: str,
        fn: llmfn_post_process.LLMCompareFnPostProcessBatchAddFn,
    ) -> LLMFunction:
        """Appends a compare-specific add post-processing command and returns `self`.

        Args:
          name: Name given to the command.
          fn: Batch comparison function wrapped by the command.

        Returns:
          This instance, so that calls can be chained.
        """
        self._post_process_cmds.append(
            llmfn_post_process_cmds.LLMCompareFnPostProcessAddCommand(name=name, fn=fn)
        )
        return self