google/generativeai/notebook/lib/llmfn_post_process_cmds.py (136 lines of code) (raw):

# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Internal representation of post-process commands for LLMFunction. This module is internal to LLMFunction and should only be used by llm_function.py. """ from __future__ import annotations import abc from typing import Sequence from google.generativeai.notebook.lib import llmfn_output_row from google.generativeai.notebook.lib import llmfn_post_process def _convert_view_to_output_row( row: llmfn_output_row.LLMFnOutputRowView, ) -> llmfn_output_row.LLMFnOutputRow: """Convenience method to convert a LLMFnOutputRowView to LLMFnOutputRow. If `row` is already a LLMFnOutputRow, return as-is for efficiency. This could potentially break encapsulation as it could let code to modify a LLMFnOutputRowView that was intended to be immutable, so it should be used with care. Args: row: An instance of LLMFnOutputRowView. Returns: An instance of LLMFnOutputRow. May be the same instance as `row` if `row` is already an instance of LLMFnOutputRow. """ if isinstance(row, llmfn_output_row.LLMFnOutputRow): return row return llmfn_output_row.LLMFnOutputRow(data=row, result_type=row.result_type()) class LLMFnPostProcessCommand(abc.ABC): """Abstract class representing post-processing commands.""" @abc.abstractmethod def name(self) -> str: """Returns the name of this post-processing command.""" class LLMFnImplPostProcessCommand(LLMFnPostProcessCommand): """Post-processing commands for LLMFunctionImpl.""" @abc.abstractmethod def run( self, rows: Sequence[llmfn_output_row.LLMFnOutputRowView] ) -> Sequence[llmfn_output_row.LLMFnOutputRow]: """Processes a batch of results and returns a new batch. Args: rows: The rows in a batch. Note that `rows` are not guaranteed to be remain unmodified. Returns: A new set of rows that should replace the batch. """ class LLMFnPostProcessReorderCommand(LLMFnImplPostProcessCommand): """A batch command processes a set of results at once. Note that a "batch" represents a set of results coming from a single prompt, as the model may produce more-than-one result for a prompt. """ def __init__(self, name: str, fn: llmfn_post_process.LLMFnPostProcessBatchReorderFn): self._name = name self._fn = fn def name(self) -> str: return self._name def run( self, rows: Sequence[llmfn_output_row.LLMFnOutputRowView], ) -> Sequence[llmfn_output_row.LLMFnOutputRow]: new_row_indices = self._fn(rows) if len(set(new_row_indices)) != len(new_row_indices): raise llmfn_post_process.PostProcessExecutionError( 'Error executing "{}": returned indices should be unique'.format(self._name) ) new_rows: list[llmfn_output_row.LLMFnOutputRow] = [] for idx in new_row_indices: if idx < 0: raise llmfn_post_process.PostProcessExecutionError( 'Error executing "{}": returned indices must be greater than or' " equal to zero, got {}".format(self._name, idx) ) if idx >= len(rows): raise llmfn_post_process.PostProcessExecutionError( 'Error executing "{}": returned indices must be less than length of' " rows (={}), got {}".format(self._name, len(rows), idx) ) new_rows.append(_convert_view_to_output_row(rows[idx])) return new_rows class LLMFnPostProcessAddCommand(LLMFnImplPostProcessCommand): """A command that adds each row with a new column. This does not change the value of the results cell. """ def __init__(self, name: str, fn: llmfn_post_process.LLMFnPostProcessBatchAddFn): self._name = name self._fn = fn def name(self) -> str: return self._name def run( self, rows: Sequence[llmfn_output_row.LLMFnOutputRowView], ) -> Sequence[llmfn_output_row.LLMFnOutputRow]: new_values = self._fn(rows) if len(new_values) != len(rows): raise llmfn_post_process.PostProcessExecutionError( 'Error executing "{}": returned length ({}) != number of input rows' " ({})".format(self._name, len(new_values), len(rows)) ) new_rows: list[llmfn_output_row.LLMFnOutputRow] = [] for new_value, row in zip(new_values, rows): new_row = _convert_view_to_output_row(row) new_row.add(key=self._name, value=new_value) new_rows.append(new_row) return new_rows class LLMFnPostProcessReplaceCommand(LLMFnImplPostProcessCommand): """A command that modifies the results in each row.""" def __init__(self, name: str, fn: llmfn_post_process.LLMFnPostProcessBatchReplaceFn): self._name = name self._fn = fn def name(self) -> str: return self._name def run( self, rows: Sequence[llmfn_output_row.LLMFnOutputRowView], ) -> Sequence[llmfn_output_row.LLMFnOutputRow]: new_values = self._fn(rows) if len(new_values) != len(rows): raise llmfn_post_process.PostProcessExecutionError( 'Error executing "{}": returned length ({}) != number of input rows' " ({})".format(self._name, len(new_values), len(rows)) ) new_rows: list[llmfn_output_row.LLMFnOutputRow] = [] for new_value, row in zip(new_values, rows): new_row = _convert_view_to_output_row(row) new_row.set_result_value(value=new_value) new_rows.append(new_row) return new_rows class LLMCompareFnPostProcessCommand(LLMFnPostProcessCommand): """Post-processing commands for LLMCompareFunction.""" @abc.abstractmethod def run( self, rows: Sequence[ tuple[ llmfn_output_row.LLMFnOutputRowView, llmfn_output_row.LLMFnOutputRowView, llmfn_output_row.LLMFnOutputRowView, ] ], ) -> Sequence[llmfn_output_row.LLMFnOutputRow]: """Processes a batch of left- and right-hand side results. Args: rows: The rows in a batch. Each row is a three-tuple containing: - The left-hand side results, - The right-hand side results, and - The current combined results Returns: A new set of rows that should replace the combined results. """ class LLMCompareFnPostProcessAddCommand(LLMCompareFnPostProcessCommand): """A command that adds each row with a new column. This does not change the value of the results cell. """ def __init__( self, name: str, fn: llmfn_post_process.LLMCompareFnPostProcessBatchAddFn, ): self._name = name self._fn = fn def name(self) -> str: return self._name def run( self, rows: Sequence[ tuple[ llmfn_output_row.LLMFnOutputRowView, llmfn_output_row.LLMFnOutputRowView, llmfn_output_row.LLMFnOutputRowView, ] ], ) -> Sequence[llmfn_output_row.LLMFnOutputRow]: new_values = self._fn([(lhs, rhs) for lhs, rhs, _ in rows]) if len(new_values) != len(rows): raise llmfn_post_process.PostProcessExecutionError( 'Error executing "{}": returned length ({}) != number of input rows' " ({})".format(self._name, len(new_values), len(rows)) ) new_rows: list[llmfn_output_row.LLMFnOutputRow] = [] for new_value, row in zip(new_values, [combined for _, _, combined in rows]): new_row = _convert_view_to_output_row(row) new_row.add(key=self._name, value=new_value) new_rows.append(new_row) return new_rows