google/generativeai/notebook/lib/llmfn_outputs.py (112 lines of code) (raw):
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Output of LLMFunction."""
from __future__ import annotations
import abc
import dataclasses
from typing import (
overload,
Any,
Callable,
Iterable,
Iterator,
Mapping,
Sequence,
)
from google.generativeai.notebook.lib import llmfn_output_row
from google.generativeai.notebook.lib import model as model_lib
import pandas
class ColumnNames:
"""Names of columns that are used to represent output."""
PROMPT_NUM = "Prompt Num"
INPUT_NUM = "Input Num"
RESULT_NUM = "Result Num"
# In the code we refer to "model_input" as the full keyword-substituted prompt
# and "prompt" as the template with placeholders.
# When displaying the results however we use "prompt" since "model_input" is
# an internal name.
MODEL_INPUT = "Prompt"
PROMPT_VARS = "Prompt vars"
TEXT_RESULT = "text_result"
@dataclasses.dataclass
class LLMFnOutputEntry:
"""The output of a single model input from LLMFunction.
A model input is a prompt where the keyword placeholders have been
substituted (by `prompt_vars`).
E.g. If we have:
prompt: "the opposite of {word} is"
prompt_vars: {"word", "hot"}
Then we will have the following model input:
model_input: "the opposite of hot is"
Note: The model may produce one-or-more results for a given model_input.
This is represented by the sequence `output_rows`.
"""
prompt_num: int
input_num: int
prompt_vars: Mapping[str, str]
output_rows: Sequence[llmfn_output_row.LLMFnOutputRow]
prompt: str | None = None
model_input: str | None = None
model_results: model_lib.ModelResults | None = None
def _has_model_input_field(outputs: Iterable[LLMFnOutputEntry]):
for entry in outputs:
if entry.model_input is not None:
return True
return False
class LLMFnOutputsBase(Sequence[LLMFnOutputEntry]):
"""Parent class for LLMFnOutputs.
This class exists mainly to avoid a circular dependency between LLMFnOutputs
and LLMFnOutputsSink. Most users should use LLMFnOutputs directly instead.
"""
def __init__(
self,
outputs: Iterable[LLMFnOutputEntry] | None = None,
):
"""Constructor.
Args:
outputs: The contents of this LLMFnOutputs instance.
"""
self._outputs: list[LLMFnOutputEntry] = list(outputs) if outputs is not None else []
# Needed for Iterable[LLMFnOutputEntry].
def __iter__(self) -> Iterator[LLMFnOutputEntry]:
return self._outputs.__iter__()
# Needed for Sequence[LLMFnOutputEntry].
def __len__(self) -> int:
return self._outputs.__len__()
# Needed for Sequence[LLMFnOutputEntry].
@overload
def __getitem__(self, x: int) -> LLMFnOutputEntry: ...
@overload
def __getitem__(self, x: slice) -> Sequence[LLMFnOutputEntry]: ...
def __getitem__(self, x: int | slice) -> LLMFnOutputEntry | Sequence[LLMFnOutputEntry]:
return self._outputs.__getitem__(x)
# Convenience methods.
def __bool__(self) -> bool:
return bool(self._outputs)
def __str__(self) -> str:
return self.as_pandas_dataframe().__str__()
# Own methods
def as_dict(self) -> Mapping[str, Sequence[Any]]:
"""Formats returned results as dictionary."""
# `data` is a table in column order, with the columns listed from left to
# right.
data = {
ColumnNames.PROMPT_NUM: [],
ColumnNames.INPUT_NUM: [],
# RESULT_NUM is special: each LLMFnOutputRow in self._outputs is
# expected to have a RESULT_NUM key.
ColumnNames.RESULT_NUM: [],
}
if _has_model_input_field(self._outputs):
data[ColumnNames.MODEL_INPUT] = []
if not self._outputs:
return data
# Add column names of added data.
# The last key in LLMFnOutputRow is special as it is considered
# the result. To preserve order in the (unlikely) event of inconsistent
# keys across rows, we first add all-but-the-last key to `total_keys_set`,
# then the last key.
# Note: `total_keys_set` is a Python dictionary instead of a Python set
# because Python dictionaries preserve the order in which entries are
# added, whereas Python sets do not.
total_keys_set: dict[str, None] = {k: None for k in data.keys()}
for output in self._outputs:
for result in output.output_rows:
for key in list(result.keys())[:-1]:
total_keys_set[key] = None
for output in self._outputs:
for result in output.output_rows:
total_keys_set[list(result.keys())[-1]] = None
# `data` represents the table as a dictionary of:
# column names -> list of values
for key in total_keys_set:
data[key] = []
next_num_rows = 1
for output in self._outputs:
for result in output.output_rows:
data[ColumnNames.PROMPT_NUM].append(output.prompt_num)
data[ColumnNames.INPUT_NUM].append(output.input_num)
if ColumnNames.MODEL_INPUT in data:
data[ColumnNames.MODEL_INPUT].append(output.model_input)
for key, value in result.items():
data[key].append(value)
# Look for empty cells and pad them with None.
for column in data.values():
if len(column) < next_num_rows:
column.append(None)
next_num_rows += 1
return data
def as_pandas_dataframe(self) -> pandas.DataFrame:
return pandas.DataFrame(self.as_dict())
class LLMFnOutputsSink(abc.ABC):
"""Abstract class representing an exporter for the output of LLMFunction.
This class could be extended to write to external documents, such as
Google Sheets.
"""
def write_outputs(self, outputs: LLMFnOutputsBase) -> None:
"""Writes `outputs` to some destination."""
class LLMFnOutputs(LLMFnOutputsBase):
"""A sequence of LLMFnOutputEntry instances.
Notes:
- Each LLMFnOutputEntry represents the results of running one model
input (see documentation for LLMFnOutputEntry for what "model input"
means.)
- A single model input may produce more-than-one text results.
"""
def __init__(
self,
outputs: Iterable[LLMFnOutputEntry] | None = None,
ipython_display_fn: Callable[[LLMFnOutputs], None] | None = None,
):
"""Constructor.
Args:
outputs: The contents of this LLMFnOutputs instance.
ipython_display_fn: An optional function for pretty-printing this instance
when it is the output of a cell in a notebook. If this argument is not
None, the _ipython_display_ method will be defined which will in turn
invoke this function.
"""
super().__init__(outputs=outputs)
if ipython_display_fn:
self._ipython_display_fn = ipython_display_fn
# We define the _ipython_display_ method only when `ipython_display_fn`
# is set. This lets us fall back to a default implementation defined by
# the notebook when `ipython_display_fn` is not set, instead of having to
# provide our own default implementation.
setattr(
self,
"_ipython_display_",
getattr(self, "_ipython_display_impl"),
)
def _ipython_display_impl(self):
"""Actual implementation of _ipython_display_.
This method should only be used invoked if self._ipython_display_fn is set.
"""
self._ipython_display_fn(self)
def export(self, sink: LLMFnOutputsSink) -> None:
"""Export contents to `sink`."""
sink.write_outputs(self)