google/generativeai/notebook/lib/llmfn_output_row.py (73 lines of code) (raw):

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LLMFnOutputRow."""
from __future__ import annotations

import abc
from typing import Any, Iterator, Mapping

# The type of value stored in a cell.
_CELLVALUETYPE = Any


def _get_name_of_type(x: type[Any]) -> str:
    """Returns a human-readable name for `x`, for use in error messages."""
    if hasattr(x, "__name__"):
        return x.__name__
    return str(x)


def _validate_is_result_type(value: Any, result_type: type[Any]) -> None:
    """Checks that `value` may be stored in the result cell.

    Args:
      value: The candidate value for the result cell.
      result_type: The required type. `typing.Any` disables the check.

    Raises:
      ValueError: If `value` is not an instance of `result_type`.
    """
    # `typing.Any` cannot be passed to isinstance(), so it is special-cased
    # to mean "accept every value".
    if result_type is Any:
        return
    if not isinstance(value, result_type):
        raise ValueError(
            'Value of last entry must be of type "{}", got "{}"'.format(
                _get_name_of_type(result_type),
                _get_name_of_type(type(value)),
            )
        )


class LLMFnOutputRowView(Mapping[str, _CELLVALUETYPE], metaclass=abc.ABCMeta):
    """Immutable view of LLMFnOutputRow."""

    # Additional methods (not required by Mapping[str, _CELLVALUETYPE])
    @abc.abstractmethod
    def __contains__(self, k: str) -> bool:
        """For expressions like: x in this_instance."""

    @abc.abstractmethod
    def __str__(self) -> str:
        """For expressions like: str(this_instance)."""

    # Own methods.
    @abc.abstractmethod
    def result_type(self) -> type[Any]:
        """Returns the type enforced for the result cell."""

    @abc.abstractmethod
    def result_value(self) -> Any:
        """Get the value of the result cell."""

    @abc.abstractmethod
    def result_key(self) -> str:
        """Get the key of the result cell."""


class LLMFnOutputRow(LLMFnOutputRowView):
    """Container that represents a single row in a table of outputs.

    We represent outputs as a table. This class represents a single row in the
    table like a dictionary, where the key is the column name and the value is
    the cell value.

    A single cell is designated the "result". This contains the output of the
    LLM model after running any post-processing functions specified by the
    user.

    In addition to behaving like a dictionary, this class provides additional
    methods, including:
    - Getting the value of the "result" cell
    - Setting the value (and optionally the key) of the "result" cell.
    - Add a new non-result cell

    Notes: As an implementation detail, the result-cell is always kept as the
    rightmost cell.
    """

    def __init__(self, data: Mapping[str, _CELLVALUETYPE], result_type: type[Any]):
        """Constructor.

        Args:
          data: The initial value of the row. The last entry will be treated as
            the result. Cannot be empty. The value of the last entry must be an
            instance of `result_type`.
          result_type: The type of the result cell. This will be enforced at
            runtime.

        Raises:
          ValueError: If `data` is empty, or its last value is not an instance
            of `result_type`.
        """
        # Copy so later mutations of this row never leak into the caller's
        # mapping (and vice versa).
        self._data: dict[str, _CELLVALUETYPE] = dict(data)
        if not self._data:
            raise ValueError("Must provide non-empty data")

        self._result_type = result_type
        # dicts preserve insertion order, so the last value is the result.
        result_value = next(reversed(self._data.values()))
        _validate_is_result_type(result_value, self._result_type)

    # Methods needed for Mapping[str, _CELLVALUETYPE]:
    def __iter__(self) -> Iterator[str]:
        return self._data.__iter__()

    def __len__(self) -> int:
        return self._data.__len__()

    def __getitem__(self, k: str) -> _CELLVALUETYPE:
        return self._data.__getitem__(k)

    # Additional methods for LLMFnOutputRowView.
    def __contains__(self, k: str) -> bool:
        return self._data.__contains__(k)

    def __str__(self) -> str:
        return "LLMFnOutputRow: {}".format(self._data.__str__())

    def result_type(self) -> type[Any]:
        return self._result_type

    def result_value(self) -> Any:
        return self._data[self.result_key()]

    def result_key(self) -> str:
        # Our invariant is that the result-cell is always the rightmost cell.
        # reversed() on a dict is O(1), unlike materializing all keys.
        return next(reversed(self._data))

    def _unique_key(self, key: str) -> str:
        """Returns `key`, or `key_<n>` with the smallest n >= 1 not yet used."""
        candidate_key = key
        idx = 1
        while candidate_key in self._data:
            candidate_key = "{}_{}".format(key, idx)
            idx += 1
        return candidate_key

    # Mutable methods.
    def set_result_value(self, value: Any, key: str | None = None) -> None:
        """Set the value of the result cell.

        Sets the value (and optionally the key) of the result cell. The result
        cell remains the rightmost cell.

        Args:
          value: The new value of the result cell. Must be an instance of the
            row's result type.
          key: Optionally change the key as well. If `key` is already used by
            a non-result cell, that cell is left untouched and the result cell
            gets a de-duplicated key instead (e.g. "key_1"), as `add()` does.

        Raises:
          ValueError: If `value` is not an instance of the row's result type.
            The row is not modified in that case.
        """
        _validate_is_result_type(value, self._result_type)

        current_key = self.result_key()
        if key is None or key == current_key:
            self._data[current_key] = value
            return

        # Remove the old result cell first so its key does not count as a
        # collision, then append under a unique key. Plain assignment to a key
        # already in use would overwrite that non-result cell in place and
        # leave a different cell as the rightmost (i.e. "result") cell.
        del self._data[current_key]
        self._data[self._unique_key(key)] = value

    def add(self, key: str, value: _CELLVALUETYPE) -> None:
        """Add a non-result cell.

        Adds a new non-result cell. This does not affect the result cell.

        Args:
          key: The key of the new cell to add. If it collides with an existing
            key, it is renamed to "key_1", "key_2", ... (first unused).
          value: The value of the new cell to add.
        """
        key = self._unique_key(key)

        # Insert the new key/value into the second rightmost position to keep
        # the result cell as the rightmost cell.
        result_key = self.result_key()
        result_value = self._data.pop(result_key)
        self._data[key] = value
        self._data[result_key] = result_value