google/generativeai/notebook/lib/llm_function.py (309 lines of code) (raw):
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LLMFunction."""
from __future__ import annotations
import abc
import dataclasses
from typing import (
AbstractSet,
Any,
Callable,
Iterable,
Mapping,
Optional,
Sequence,
Union,
)
from google.generativeai.notebook.lib import llmfn_input_utils
from google.generativeai.notebook.lib import llmfn_output_row
from google.generativeai.notebook.lib import llmfn_outputs
from google.generativeai.notebook.lib import llmfn_post_process
from google.generativeai.notebook.lib import llmfn_post_process_cmds
from google.generativeai.notebook.lib import model as model_lib
from google.generativeai.notebook.lib import prompt_utils
# In the same spirit as post-processing functions (see: llmfn_post_process.py),
# we keep the LLM functions more flexible by providing the entire left- and
# right-hand side rows to the user-defined comparison function.
#
# Possible use-cases include adding a scoring function as a post-process
# command, then comparing the scores.
# Signature of a user-supplied comparison function: takes the left- and
# right-hand output rows and returns an arbitrary comparison value.
CompareFn = Callable[
    [llmfn_output_row.LLMFnOutputRowView, llmfn_output_row.LLMFnOutputRowView],
    Any,
]
def _is_equal_fn(
lhs: llmfn_output_row.LLMFnOutputRowView,
rhs: llmfn_output_row.LLMFnOutputRowView,
) -> bool:
"""Default function used when comparing outputs."""
return lhs.result_value() == rhs.result_value()
def _convert_compare_fn_to_batch_add_fn(
fn: Callable[
[
llmfn_output_row.LLMFnOutputRowView,
llmfn_output_row.LLMFnOutputRowView,
],
Any,
],
) -> llmfn_post_process.LLMCompareFnPostProcessBatchAddFn:
"""Vectorize a single-row-based comparison function."""
def _fn(
lhs_and_rhs_rows: Sequence[
tuple[
llmfn_output_row.LLMFnOutputRowView,
llmfn_output_row.LLMFnOutputRowView,
]
],
) -> Sequence[Any]:
return [fn(lhs, rhs) for lhs, rhs in lhs_and_rhs_rows]
return _fn
@dataclasses.dataclass
class _PromptInfo:
    """Fields describing a single fully-substituted prompt execution."""

    # Index of the prompt template within the sequence of prompts.
    prompt_num: int
    # The original (un-substituted) prompt template.
    prompt: str
    # Index of the input row used for keyword substitution.
    input_num: int
    # The key/value pairs substituted into the template's placeholders.
    prompt_vars: Mapping[str, str]
    # The final prompt text after substitution, i.e. what is sent to the model.
    model_input: str
def _generate_prompts(
    prompts: Sequence[str], inputs: llmfn_input_utils.LLMFunctionInputs | None
) -> Iterable[_PromptInfo]:
    """Yield a _PromptInfo for every (prompt, input) combination.

    Args:
      prompts: A list of prompts, with optional keyword placeholders.
      inputs: A list of key/value pairs to substitute into placeholders in
        `prompts`.

    Yields:
      A _PromptInfo instance per prompt/input pairing, in prompt-major order.
    """
    if inputs is None:
        normalized_inputs: Sequence[Mapping[str, str]] = []
    else:
        normalized_inputs = llmfn_input_utils.to_normalized_inputs(inputs)
    if not normalized_inputs:
        # Use a single empty substitution so each prompt runs at least once.
        normalized_inputs = [{}]
    for prompt_num, prompt in enumerate(prompts):
        for input_num, prompt_vars in enumerate(normalized_inputs):
            yield _PromptInfo(
                prompt_num=prompt_num,
                prompt=prompt,
                input_num=input_num,
                prompt_vars=prompt_vars,
                # Keyword-substituted prompt actually sent to the model.
                model_input=prompt.format(**prompt_vars),
            )
class LLMFunction(
    Callable[
        [Union[llmfn_input_utils.LLMFunctionInputs, None]],
        llmfn_outputs.LLMFnOutputs,
    ],
    metaclass=abc.ABCMeta,
):
    """Base class for LLMFunctionImpl and LLMCompareFunction."""

    def __init__(
        self,
        outputs_ipython_display_fn: Callable[[llmfn_outputs.LLMFnOutputs], None] | None = None,
    ):
        """Constructor.

        Args:
          outputs_ipython_display_fn: Optional function that will be used to
            override how the outputs of this LLMFunction will be displayed in a
            notebook (See further documentation in LLMFnOutputs.__init__().)
        """
        # Post-process commands run by subclasses over their raw output rows.
        self._post_process_cmds: list[llmfn_post_process_cmds.LLMFnPostProcessCommand] = []
        self._outputs_ipython_display_fn = outputs_ipython_display_fn

    @abc.abstractmethod
    def get_placeholders(self) -> AbstractSet[str]:
        """Returns the placeholders that should be present in inputs for this function."""

    @abc.abstractmethod
    def _call_impl(
        self, inputs: llmfn_input_utils.LLMFunctionInputs | None
    ) -> Sequence[llmfn_outputs.LLMFnOutputEntry]:
        """Concrete implementation of __call__()."""

    def __call__(
        self, inputs: llmfn_input_utils.LLMFunctionInputs | None = None
    ) -> llmfn_outputs.LLMFnOutputs:
        """Runs and returns results based on `inputs`."""
        entries = self._call_impl(inputs)
        return llmfn_outputs.LLMFnOutputs(
            outputs=entries,
            ipython_display_fn=self._outputs_ipython_display_fn,
        )

    def _append_post_process_cmd(
        self, cmd: llmfn_post_process_cmds.LLMFnPostProcessCommand
    ) -> LLMFunction:
        """Appends `cmd` to the post-process pipeline; returns self for chaining."""
        self._post_process_cmds.append(cmd)
        return self

    def add_post_process_reorder_fn(
        self, name: str, fn: llmfn_post_process.LLMFnPostProcessBatchReorderFn
    ) -> LLMFunction:
        """Registers an LLMFnPostProcessReorderCommand built from `name` and `fn`."""
        return self._append_post_process_cmd(
            llmfn_post_process_cmds.LLMFnPostProcessReorderCommand(name=name, fn=fn)
        )

    def add_post_process_add_fn(
        self,
        name: str,
        fn: llmfn_post_process.LLMFnPostProcessBatchAddFn,
    ) -> LLMFunction:
        """Registers an LLMFnPostProcessAddCommand built from `name` and `fn`."""
        return self._append_post_process_cmd(
            llmfn_post_process_cmds.LLMFnPostProcessAddCommand(name=name, fn=fn)
        )

    def add_post_process_replace_fn(
        self,
        name: str,
        fn: llmfn_post_process.LLMFnPostProcessBatchReplaceFn,
    ) -> LLMFunction:
        """Registers an LLMFnPostProcessReplaceCommand built from `name` and `fn`."""
        return self._append_post_process_cmd(
            llmfn_post_process_cmds.LLMFnPostProcessReplaceCommand(name=name, fn=fn)
        )
class LLMFunctionImpl(LLMFunction):
    """Callable class that executes the contents of a Magics cell.

    An LLMFunction is constructed from the Magics command line and cell contents
    specified by the user. It is defined by:
    - A model instance,
    - Model arguments
    - A prompt template (e.g. "the opposite of hot is {word}") with an optional
      keyword placeholder.

    The LLMFunction takes as its input a sequence of dictionaries containing
    values for keyword replacement, e.g. [{"word": "hot"}, {"word": "tall"}].
    This will cause the model to be executed with the following prompts:
      "The opposite of hot is"
      "The opposite of tall is"

    The results will be returned in a LLMFnOutputs instance.
    """

    def __init__(
        self,
        model: model_lib.AbstractModel,
        prompts: Sequence[str],
        model_args: model_lib.ModelArguments | None = None,
        outputs_ipython_display_fn: Callable[[llmfn_outputs.LLMFnOutputs], None] | None = None,
    ):
        """Constructor.

        Args:
          model: The model that the prompts will execute on.
          prompts: A sequence of prompt templates with optional placeholders. The
            placeholders will be replaced by the inputs passed into this function.
          model_args: Optional set of model arguments to configure how the model
            executes the prompts.
          outputs_ipython_display_fn: See documentation in LLMFunction.__init__().
        """
        super().__init__(outputs_ipython_display_fn=outputs_ipython_display_fn)
        self._model = model
        self._prompts = prompts
        self._model_args = model_lib.ModelArguments() if model_args is None else model_args

        # Compute the union of placeholders across all prompt templates once,
        # accumulating in a local before assigning the attribute.
        placeholders: frozenset[str] = frozenset()
        for prompt in self._prompts:
            placeholders = placeholders.union(prompt_utils.get_placeholders(prompt))
        self._placeholders = placeholders

    def _run_post_processing_cmds(
        self, results: Sequence[llmfn_output_row.LLMFnOutputRow]
    ) -> Sequence[llmfn_output_row.LLMFnOutputRow]:
        """Runs post-processing commands over `results`.

        Args:
          results: The output rows produced by the model call.

        Returns:
          The (possibly transformed) output rows after all commands have run.

        Raises:
          PostProcessExecutionError: If a command is of an unsupported type, or
            if a command raises a RuntimeError during execution.
        """
        for cmd in self._post_process_cmds:
            try:
                if isinstance(cmd, llmfn_post_process_cmds.LLMFnImplPostProcessCommand):
                    results = cmd.run(results)
                else:
                    raise llmfn_post_process.PostProcessExecutionError(
                        "Unsupported post-process command type: {}".format(type(cmd))
                    )
            except llmfn_post_process.PostProcessExecutionError:
                raise
            except RuntimeError as e:
                # Chain the original exception (`from e`) so the underlying
                # traceback is preserved for debugging.
                raise llmfn_post_process.PostProcessExecutionError(
                    'Error executing "{}", got {}: {}'.format(cmd.name(), type(e).__name__, e)
                ) from e
        return results

    def get_placeholders(self) -> AbstractSet[str]:
        """Returns the union of placeholders found in all prompt templates."""
        return self._placeholders

    def _call_impl(
        self, inputs: llmfn_input_utils.LLMFunctionInputs | None
    ) -> Sequence[llmfn_outputs.LLMFnOutputEntry]:
        """Runs each (prompt, input) combination through the model.

        Args:
          inputs: Optional key/value pairs substituted into the prompts.

        Returns:
          One LLMFnOutputEntry per prompt/input combination, each containing
          one post-processed output row per text result from the model.
        """
        results: list[llmfn_outputs.LLMFnOutputEntry] = []
        for info in _generate_prompts(prompts=self._prompts, inputs=inputs):
            model_results = self._model.call_model(
                model_input=info.model_input, model_args=self._model_args
            )
            # One row per text result; RESULT_NUM records the result's index.
            output_rows: list[llmfn_output_row.LLMFnOutputRow] = []
            for result_num, text_result in enumerate(model_results.text_results):
                output_rows.append(
                    llmfn_output_row.LLMFnOutputRow(
                        data={
                            llmfn_outputs.ColumnNames.RESULT_NUM: result_num,
                            llmfn_outputs.ColumnNames.TEXT_RESULT: text_result,
                        },
                        result_type=str,
                    )
                )
            results.append(
                llmfn_outputs.LLMFnOutputEntry(
                    prompt_num=info.prompt_num,
                    input_num=info.input_num,
                    prompt=info.prompt,
                    prompt_vars=info.prompt_vars,
                    model_input=info.model_input,
                    model_results=model_results,
                    output_rows=self._run_post_processing_cmds(output_rows),
                )
            )
        return results
class LLMCompareFunction(LLMFunction):
    """LLMFunction for comparisons.

    LLMCompareFunction runs an input over a pair of LLMFunctions and compares
    the result.
    """

    def __init__(
        self,
        lhs_name_and_fn: tuple[str, LLMFunction],
        rhs_name_and_fn: tuple[str, LLMFunction],
        compare_name_and_fns: Sequence[tuple[str, CompareFn]] | None = None,
        outputs_ipython_display_fn: Callable[[llmfn_outputs.LLMFnOutputs], None] | None = None,
    ):
        """Constructor.

        Args:
          lhs_name_and_fn: Name and function for the left-hand side of the
            comparison.
          rhs_name_and_fn: Name and function for the right-hand side of the
            comparison.
          compare_name_and_fns: Optional names and functions for comparing the
            results of the left- and right-hand sides.
          outputs_ipython_display_fn: See documentation in LLMFunction.__init__().
        """
        super().__init__(outputs_ipython_display_fn=outputs_ipython_display_fn)
        self._lhs_name: str = lhs_name_and_fn[0]
        self._lhs_fn: LLMFunction = lhs_name_and_fn[1]
        self._rhs_name: str = rhs_name_and_fn[0]
        self._rhs_fn: LLMFunction = rhs_name_and_fn[1]
        # Inputs must satisfy the placeholders of both sides.
        self._placeholders = frozenset(self._lhs_fn.get_placeholders()).union(
            self._rhs_fn.get_placeholders()
        )

        if not compare_name_and_fns:
            # Default comparison: equality of the two rows' result values.
            self._result_name = "is_equal"
            self._result_compare_fn = _is_equal_fn
        else:
            # Assume the last entry in `compare_name_and_fns` is the one that
            # produces value for the result cell.
            name, fn = compare_name_and_fns[-1]
            self._result_name = name
            self._result_compare_fn = fn
            # Treat the other compare_fns as post-processing operators.
            for name, cmp_fn in compare_name_and_fns[:-1]:
                self.add_compare_post_process_add_fn(
                    name=name, fn=_convert_compare_fn_to_batch_add_fn(cmp_fn)
                )

    def _run_post_processing_cmds(
        self,
        lhs_output_rows: Sequence[llmfn_output_row.LLMFnOutputRow],
        rhs_output_rows: Sequence[llmfn_output_row.LLMFnOutputRow],
        results: Sequence[llmfn_output_row.LLMFnOutputRow],
    ) -> Sequence[llmfn_output_row.LLMFnOutputRow]:
        """Runs post-processing commands over `results`.

        Args:
          lhs_output_rows: Output rows from the left-hand function.
          rhs_output_rows: Output rows from the right-hand function.
          results: The combined comparison rows to post-process.

        Returns:
          The (possibly transformed) comparison rows after all commands.

        Raises:
          PostProcessExecutionError: If a command is of an unsupported type, or
            if a command raises a RuntimeError during execution.
        """
        for cmd in self._post_process_cmds:
            try:
                if isinstance(cmd, llmfn_post_process_cmds.LLMFnImplPostProcessCommand):
                    results = cmd.run(results)
                elif isinstance(cmd, llmfn_post_process_cmds.LLMCompareFnPostProcessCommand):
                    # Compare-style commands see the lhs/rhs rows alongside the
                    # combined result row.
                    results = cmd.run(list(zip(lhs_output_rows, rhs_output_rows, results)))
                else:
                    # Raise PostProcessExecutionError directly (consistent with
                    # LLMFunctionImpl) instead of a RuntimeError that would be
                    # re-wrapped by the handler below.
                    raise llmfn_post_process.PostProcessExecutionError(
                        "Unsupported post-process command type: {}".format(type(cmd))
                    )
            except llmfn_post_process.PostProcessExecutionError:
                raise
            except RuntimeError as e:
                # Chain the original exception (`from e`) so the underlying
                # traceback is preserved for debugging.
                raise llmfn_post_process.PostProcessExecutionError(
                    'Error executing "{}", got {}: {}'.format(cmd.name(), type(e).__name__, e)
                ) from e
        return results

    def get_placeholders(self) -> AbstractSet[str]:
        """Returns the union of placeholders required by both sides."""
        return self._placeholders

    def _call_impl(
        self, inputs: llmfn_input_utils.LLMFunctionInputs | None
    ) -> Sequence[llmfn_outputs.LLMFnOutputEntry]:
        """Runs both sides over `inputs` and combines their outputs row-by-row.

        Args:
          inputs: Optional key/value pairs passed to both functions.

        Returns:
          One LLMFnOutputEntry per matched pair of entries, containing the
          comparison value plus the columns from both sides.

        Raises:
          RuntimeError: If the two sides' entries disagree on prompt_num,
            input_num, or prompt_vars.
        """
        lhs_results = self._lhs_fn(inputs)
        rhs_results = self._rhs_fn(inputs)

        # Combine the results. NOTE(review): zip() silently stops at the
        # shorter of the two sequences if their lengths ever differ.
        outputs: list[llmfn_outputs.LLMFnOutputEntry] = []
        for lhs_entry, rhs_entry in zip(lhs_results, rhs_results):
            # Sanity-check that the two sides were run over the same inputs.
            if lhs_entry.prompt_num != rhs_entry.prompt_num:
                raise RuntimeError(
                    "Prompt num mismatch: {} vs {}".format(
                        lhs_entry.prompt_num, rhs_entry.prompt_num
                    )
                )
            if lhs_entry.input_num != rhs_entry.input_num:
                raise RuntimeError(
                    "Input num mismatch: {} vs {}".format(lhs_entry.input_num, rhs_entry.input_num)
                )
            if lhs_entry.prompt_vars != rhs_entry.prompt_vars:
                raise RuntimeError(
                    "Prompt vars mismatch: {} vs {}".format(
                        lhs_entry.prompt_vars, rhs_entry.prompt_vars
                    )
                )

            # The two functions may have different numbers of results due to
            # options like candidate_count, so we can only compare up to the
            # minimum of the two.
            num_output_rows = min(len(lhs_entry.output_rows), len(rhs_entry.output_rows))
            lhs_output_rows = lhs_entry.output_rows[:num_output_rows]
            rhs_output_rows = rhs_entry.output_rows[:num_output_rows]

            output_rows: list[llmfn_output_row.LLMFnOutputRow] = []
            for result_num, lhs_and_rhs_output_row in enumerate(
                zip(lhs_output_rows, rhs_output_rows)
            ):
                lhs_output_row, rhs_output_row = lhs_and_rhs_output_row
                # Combine cells from lhs_output_row and rhs_output_row into a
                # single row.
                # Although it is possible for RESULT_NUM (the index of each
                # text_result if a prompt produces multiple text_results) to be
                # different between the left and right sides, we ignore their
                # RESULT_NUM entries and write our own.
                row_data: dict[str, Any] = {
                    llmfn_outputs.ColumnNames.RESULT_NUM: result_num,
                    self._result_name: self._result_compare_fn(lhs_output_row, rhs_output_row),
                }
                output_row = llmfn_output_row.LLMFnOutputRow(data=row_data, result_type=Any)

                # Add the prompt vars.
                output_row.add(llmfn_outputs.ColumnNames.PROMPT_VARS, lhs_entry.prompt_vars)

                # Add the results from the left-hand side and right-hand side,
                # prefixing each column with the side's name.
                for name, row in [
                    (self._lhs_name, lhs_output_row),
                    (self._rhs_name, rhs_output_row),
                ]:
                    for k, v in row.items():
                        if k != llmfn_outputs.ColumnNames.RESULT_NUM:
                            # We use LLMFnOutputRow.add() because it handles column
                            # name collisions.
                            output_row.add("{}_{}".format(name, k), v)
                output_rows.append(output_row)

            outputs.append(
                llmfn_outputs.LLMFnOutputEntry(
                    prompt_num=lhs_entry.prompt_num,
                    input_num=lhs_entry.input_num,
                    prompt_vars=lhs_entry.prompt_vars,
                    output_rows=self._run_post_processing_cmds(
                        lhs_output_rows=lhs_output_rows,
                        rhs_output_rows=rhs_output_rows,
                        results=output_rows,
                    ),
                )
            )
        return outputs

    def add_compare_post_process_add_fn(
        self,
        name: str,
        fn: llmfn_post_process.LLMCompareFnPostProcessBatchAddFn,
    ) -> LLMFunction:
        """Registers an LLMCompareFnPostProcessAddCommand built from `name` and `fn`."""
        self._post_process_cmds.append(
            llmfn_post_process_cmds.LLMCompareFnPostProcessAddCommand(name=name, fn=fn)
        )
        return self