google/generativeai/notebook/lib/llmfn_output_row.py (73 lines of code) (raw):
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LLMFnOutputRow."""
from __future__ import annotations
import abc
from typing import Any, Iterator, Mapping
# The type of value stored in a cell.
# Aliased to `Any`, i.e. cell values are not constrained by the type checker;
# the alias is used in annotations throughout this module.
_CELLVALUETYPE = Any
def _get_name_of_type(x: type[Any]) -> str:
    """Returns a printable name for `x`.

    Args:
      x: The object (normally a type) to be named.

    Returns:
      `x.__name__` when that attribute exists, otherwise `str(x)`.
    """
    try:
        return x.__name__
    except AttributeError:
        # Objects without a `__name__` fall back to their string form.
        return str(x)
def _validate_is_result_type(value: Any, result_type: type[Any]) -> None:
    """Checks that `value` is acceptable for a result cell of `result_type`.

    Args:
      value: The candidate value for the result cell.
      result_type: The type the result cell must have. `Any` accepts every
        value.

    Raises:
      ValueError: If `result_type` is not `Any` and `value` is not an instance
        of `result_type`.
    """
    # `Any` is compared first because it cannot be handed to isinstance().
    if result_type == Any or isinstance(value, result_type):
        return

    expected_name = _get_name_of_type(result_type)
    actual_name = _get_name_of_type(type(value))
    raise ValueError(
        'Value of last entry must be of type "{}", got "{}"'.format(
            expected_name, actual_name
        )
    )
class LLMFnOutputRowView(Mapping[str, _CELLVALUETYPE], metaclass=abc.ABCMeta):
    """Read-only interface onto a single row of outputs.

    Behaves like a `Mapping` from key to cell value and additionally declares
    accessors for the cell designated as the "result".
    """

    # Accessors specific to the result cell.

    @abc.abstractmethod
    def result_key(self) -> str:
        """Returns the key of the result cell."""

    @abc.abstractmethod
    def result_value(self) -> Any:
        """Returns the value held in the result cell."""

    @abc.abstractmethod
    def result_type(self) -> type[Any]:
        """Returns the type enforced for the result cell."""

    # Additional dunder methods (not required by Mapping[str, _CELLVALUETYPE])
    # that implementations must also provide.

    @abc.abstractmethod
    def __str__(self) -> str:
        """Supports expressions like: str(this_instance)."""

    @abc.abstractmethod
    def __contains__(self, k: str) -> bool:
        """Supports expressions like: x in this_instance."""
class LLMFnOutputRow(LLMFnOutputRowView):
    """Container that represents a single row in a table of outputs.

    We represent outputs as a table. This class represents a single row in the
    table like a dictionary, where the key is the column name and the value is
    the cell value.

    A single cell is designated the "result". This contains the output of the
    LLM model after running any post-processing functions specified by the
    user.

    In addition to behaving like a dictionary, this class provides additional
    methods, including:
    - Getting the value of the "result" cell
    - Setting the value (and optionally the key) of the "result" cell.
    - Adding a new non-result cell

    Notes: As an implementation detail, the result-cell is always kept as the
    rightmost cell, i.e. the last entry of the underlying dict.
    """

    def __init__(self, data: Mapping[str, _CELLVALUETYPE], result_type: type[Any]):
        """Constructor.

        Args:
          data: The initial value of the row. The last entry will be treated as
            the result. Cannot be empty. The value of the last entry must be an
            instance of `result_type`.
          result_type: The type of the result cell. This will be enforced at
            runtime; `Any` disables the check.

        Raises:
          ValueError: If `data` is empty, or if the value of its last entry is
            not an instance of `result_type`.
        """
        # Copy so that later mutations do not leak back into the caller's
        # mapping.
        self._data: dict[str, _CELLVALUETYPE] = dict(data)
        if not self._data:
            raise ValueError("Must provide non-empty data")

        self._result_type = result_type
        # The last entry is the result cell; reversed() reaches it without
        # copying the whole dict into a list.
        result_value = self._data[next(reversed(self._data))]
        _validate_is_result_type(result_value, self._result_type)

    # Methods needed for Mapping[str, _CELLVALUETYPE]:
    def __iter__(self) -> Iterator[str]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, k: str) -> _CELLVALUETYPE:
        return self._data[k]

    # Additional methods for LLMFnOutputRowView.
    def __contains__(self, k: str) -> bool:
        return k in self._data

    def __str__(self) -> str:
        return "LLMFnOutputRow: {}".format(str(self._data))

    def result_type(self) -> type[Any]:
        """Returns the type enforced for the result cell."""
        return self._result_type

    def result_value(self) -> Any:
        """Returns the value of the result cell."""
        return self._data[self.result_key()]

    def result_key(self) -> str:
        """Returns the key of the result cell."""
        # Our invariant is that the result-cell is always the rightmost cell.
        return next(reversed(self._data))

    # Mutable methods.
    def set_result_value(self, value: Any, key: str | None = None) -> None:
        """Set the value of the result cell.

        Sets the value (and optionally the key) of the result cell.

        Args:
          value: The value to set the result cell to.
          key: Optionally change the key as well. If `key` is already used by a
            non-result cell, that cell is replaced by the result cell.

        Raises:
          ValueError: If `value` is not an instance of the result type.
        """
        _validate_is_result_type(value, self._result_type)

        current_key = self.result_key()
        if key is None or key == current_key:
            self._data[current_key] = value
            return

        del self._data[current_key]
        # Assigning to a key that already exists keeps that key's original
        # position in the dict, which would leave some other cell rightmost and
        # break the result-is-last invariant. Drop any colliding cell first so
        # the result is always re-inserted at the end.
        self._data.pop(key, None)
        self._data[key] = value

    def add(self, key: str, value: _CELLVALUETYPE) -> None:
        """Add a non-result cell.

        Adds a new non-result cell. This does not affect the result cell.

        Args:
          key: The key of the new cell to add. If it is already in use, a
            suffix "_1", "_2", ... is appended until the key is unique.
          value: The value of the new cell to add.
        """
        # Handle collisions with `key`.
        candidate_key = key
        idx = 1
        while candidate_key in self._data:
            candidate_key = "{}_{}".format(key, idx)
            idx += 1

        # Insert the new key/value into the second rightmost position to keep
        # the result cell as the rightmost cell: take the result out, append
        # the new cell, then put the result back at the end.
        result_key = self.result_key()
        result_value = self._data.pop(result_key)
        self._data[candidate_key] = value
        self._data[result_key] = result_value