google/generativeai/notebook/lib/llmfn_post_process_cmds.py (136 lines of code) (raw):
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Internal representation of post-process commands for LLMFunction.
This module is internal to LLMFunction and should only be used by
llm_function.py.
"""
from __future__ import annotations
import abc
from typing import Sequence
from google.generativeai.notebook.lib import llmfn_output_row
from google.generativeai.notebook.lib import llmfn_post_process
def _convert_view_to_output_row(
    row: llmfn_output_row.LLMFnOutputRowView,
) -> llmfn_output_row.LLMFnOutputRow:
    """Returns `row` as an LLMFnOutputRow.

    When `row` is already an LLMFnOutputRow it is handed back untouched, which
    avoids constructing a new object. Be careful with this shortcut: the caller
    receives the very same object it passed in, so code that was only given a
    LLMFnOutputRowView (intended to be immutable) may end up modifying it.

    Args:
      row: An instance of LLMFnOutputRowView.

    Returns:
      An instance of LLMFnOutputRow. This is `row` itself when `row` is already
      an LLMFnOutputRow; otherwise it is a new LLMFnOutputRow constructed from
      `row` and its result type.
    """
    output_row_cls = llmfn_output_row.LLMFnOutputRow
    if not isinstance(row, output_row_cls):
        # Only a view: construct a row from it, carrying over the result type.
        return output_row_cls(data=row, result_type=row.result_type())
    return row
class LLMFnPostProcessCommand(abc.ABC):
    """Abstract base for every post-processing command.

    Subclasses must implement `name()`.
    """

    @abc.abstractmethod
    def name(self) -> str:
        """Returns the name identifying this post-processing command."""
class LLMFnImplPostProcessCommand(LLMFnPostProcessCommand):
    """Abstract post-processing command used with LLMFunctionImpl.

    Implementations take one batch of rows and produce the batch that
    replaces it.
    """

    @abc.abstractmethod
    def run(
        self, rows: Sequence[llmfn_output_row.LLMFnOutputRowView]
    ) -> Sequence[llmfn_output_row.LLMFnOutputRow]:
        """Processes one batch of results and returns the replacement batch.

        Args:
          rows: The rows in a batch. Callers must not rely on `rows` remaining
            unmodified after this call.

        Returns:
          A new set of rows that should replace the batch.
        """
class LLMFnPostProcessReorderCommand(LLMFnImplPostProcessCommand):
    """A batch command that selects and orders rows by index.

    The wrapped function receives the whole batch at once and returns the
    indices of the rows to emit, in the order they should be emitted.

    Note that a "batch" represents a set of results coming from a single prompt,
    as the model may produce more-than-one result for a prompt.
    """

    def __init__(self, name: str, fn: llmfn_post_process.LLMFnPostProcessBatchReorderFn):
        self._name = name
        self._fn = fn

    def name(self) -> str:
        """Returns the name given to this command at construction."""
        return self._name

    def _validated_index(
        self,
        idx: int,
        rows: Sequence[llmfn_output_row.LLMFnOutputRowView],
    ) -> int:
        """Returns `idx` unchanged after checking that it addresses a row in `rows`.

        Raises:
          PostProcessExecutionError: If `idx` is negative, or is not less than
            the number of rows.
        """
        if idx < 0:
            raise llmfn_post_process.PostProcessExecutionError(
                'Error executing "{}": returned indices must be greater than or'
                " equal to zero, got {}".format(self._name, idx)
            )
        if idx >= len(rows):
            raise llmfn_post_process.PostProcessExecutionError(
                'Error executing "{}": returned indices must be less than length of'
                " rows (={}), got {}".format(self._name, len(rows), idx)
            )
        return idx

    def run(
        self,
        rows: Sequence[llmfn_output_row.LLMFnOutputRowView],
    ) -> Sequence[llmfn_output_row.LLMFnOutputRow]:
        """Returns the rows picked by the wrapped function, in the order it chose.

        Args:
          rows: The rows in a batch.

        Returns:
          A list with one LLMFnOutputRow per returned index.

        Raises:
          PostProcessExecutionError: If the returned indices contain duplicates,
            or if any index is negative or not less than `len(rows)`.
        """
        chosen_indices = self._fn(rows)

        # Duplicates are rejected up front, before any index is range-checked.
        if len(set(chosen_indices)) != len(chosen_indices):
            raise llmfn_post_process.PostProcessExecutionError(
                'Error executing "{}": returned indices should be unique'.format(self._name)
            )

        # Each index is validated immediately before its row is converted.
        return [
            _convert_view_to_output_row(rows[self._validated_index(idx, rows)])
            for idx in chosen_indices
        ]
class LLMFnPostProcessAddCommand(LLMFnImplPostProcessCommand):
    """A command that attaches a new column to every row.

    The column is keyed by the command's name. This does not change the value
    of the results cell.
    """

    def __init__(self, name: str, fn: llmfn_post_process.LLMFnPostProcessBatchAddFn):
        self._name = name
        self._fn = fn

    def name(self) -> str:
        """Returns the name given to this command at construction."""
        return self._name

    def run(
        self,
        rows: Sequence[llmfn_output_row.LLMFnOutputRowView],
    ) -> Sequence[llmfn_output_row.LLMFnOutputRow]:
        """Adds the values computed by the wrapped function as a new column.

        Args:
          rows: The rows in a batch. A row that is already an LLMFnOutputRow is
            modified in place rather than rebuilt.

        Returns:
          A list with one LLMFnOutputRow per input row, each with the new
          value added under this command's name.

        Raises:
          PostProcessExecutionError: If the wrapped function does not return
            exactly one value per input row.
        """
        column_values = self._fn(rows)
        num_values, num_rows = len(column_values), len(rows)
        if num_values != num_rows:
            raise llmfn_post_process.PostProcessExecutionError(
                'Error executing "{}": returned length ({}) != number of input rows'
                " ({})".format(self._name, num_values, num_rows)
            )

        def _with_column(value, row):
            # Convert (or reuse) the row, then attach the value under our name.
            output_row = _convert_view_to_output_row(row)
            output_row.add(key=self._name, value=value)
            return output_row

        return [_with_column(value, row) for value, row in zip(column_values, rows)]
class LLMFnPostProcessReplaceCommand(LLMFnImplPostProcessCommand):
    """A command that overwrites the result value of each row."""

    def __init__(self, name: str, fn: llmfn_post_process.LLMFnPostProcessBatchReplaceFn):
        self._name = name
        self._fn = fn

    def name(self) -> str:
        """Returns the name given to this command at construction."""
        return self._name

    def run(
        self,
        rows: Sequence[llmfn_output_row.LLMFnOutputRowView],
    ) -> Sequence[llmfn_output_row.LLMFnOutputRow]:
        """Replaces each row's result with the value from the wrapped function.

        Args:
          rows: The rows in a batch. A row that is already an LLMFnOutputRow is
            modified in place rather than rebuilt.

        Returns:
          A list with one LLMFnOutputRow per input row, each carrying its
          replacement result value.

        Raises:
          PostProcessExecutionError: If the wrapped function does not return
            exactly one value per input row.
        """
        replacements = self._fn(rows)
        if len(replacements) != len(rows):
            message = (
                'Error executing "{}": returned length ({}) != number of input rows'
                " ({})".format(self._name, len(replacements), len(rows))
            )
            raise llmfn_post_process.PostProcessExecutionError(message)

        replaced_rows: list[llmfn_output_row.LLMFnOutputRow] = []
        for replacement, original in zip(replacements, rows):
            output_row = _convert_view_to_output_row(original)
            output_row.set_result_value(value=replacement)
            replaced_rows.append(output_row)
        return replaced_rows
class LLMCompareFnPostProcessCommand(LLMFnPostProcessCommand):
    """Abstract post-processing command used with LLMCompareFunction."""

    @abc.abstractmethod
    def run(
        self,
        rows: Sequence[
            tuple[
                llmfn_output_row.LLMFnOutputRowView,
                llmfn_output_row.LLMFnOutputRowView,
                llmfn_output_row.LLMFnOutputRowView,
            ]
        ],
    ) -> Sequence[llmfn_output_row.LLMFnOutputRow]:
        """Processes a batch of left- and right-hand side results.

        Args:
          rows: The rows in a batch. Each row is a three-tuple containing, in
            order:
              - the left-hand side results,
              - the right-hand side results, and
              - the current combined results.

        Returns:
          A new set of rows that should replace the combined results.
        """
class LLMCompareFnPostProcessAddCommand(LLMCompareFnPostProcessCommand):
    """A command that attaches a new column to every combined row.

    The column is keyed by the command's name. This does not change the value
    of the results cell.
    """

    def __init__(
        self,
        name: str,
        fn: llmfn_post_process.LLMCompareFnPostProcessBatchAddFn,
    ):
        self._name = name
        self._fn = fn

    def name(self) -> str:
        """Returns the name given to this command at construction."""
        return self._name

    def run(
        self,
        rows: Sequence[
            tuple[
                llmfn_output_row.LLMFnOutputRowView,
                llmfn_output_row.LLMFnOutputRowView,
                llmfn_output_row.LLMFnOutputRowView,
            ]
        ],
    ) -> Sequence[llmfn_output_row.LLMFnOutputRow]:
        """Adds a column computed from the left- and right-hand side results.

        Args:
          rows: The rows in a batch, each a three-tuple of (left-hand side
            results, right-hand side results, current combined results). A
            combined row that is already an LLMFnOutputRow is modified in
            place rather than rebuilt.

        Returns:
          A list with one LLMFnOutputRow per input tuple, built from the
          combined results with the new value added under this command's name.

        Raises:
          PostProcessExecutionError: If the wrapped function does not return
            exactly one value per input row.
        """
        # The wrapped function only sees the (lhs, rhs) pair of each tuple.
        lhs_rhs_pairs = [(lhs, rhs) for lhs, rhs, _ in rows]
        column_values = self._fn(lhs_rhs_pairs)
        if len(column_values) != len(rows):
            raise llmfn_post_process.PostProcessExecutionError(
                'Error executing "{}": returned length ({}) != number of input rows'
                " ({})".format(self._name, len(column_values), len(rows))
            )

        # The new column is attached to the combined results (third element).
        combined_rows = [combined for _, _, combined in rows]
        augmented_rows: list[llmfn_output_row.LLMFnOutputRow] = []
        for value, combined_row in zip(column_values, combined_rows):
            output_row = _convert_view_to_output_row(combined_row)
            output_row.add(key=self._name, value=value)
            augmented_rows.append(output_row)
        return augmented_rows