google/generativeai/notebook/post_process_utils.py (75 lines of code) (raw):

# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for working with post-processing tokens.""" from __future__ import annotations import abc from typing import Any, Callable, Sequence from google.generativeai.notebook import py_utils from google.generativeai.notebook.lib import llm_function from google.generativeai.notebook.lib import llmfn_output_row from google.generativeai.notebook.lib import llmfn_post_process class PostProcessParseError(RuntimeError): """An error parsing the post-processing tokens.""" class ParsedPostProcessExpr(abc.ABC): """A post-processing expression parsed from the command line.""" @abc.abstractmethod def name(self) -> str: """Returns the name of this expression.""" @abc.abstractmethod def add_to_llm_function(self, llm_fn: llm_function.LLMFunction) -> llm_function.LLMFunction: """Adds this parsed expression to `llm_fn` as a post-processing command.""" class _ParsedPostProcessAddExpr( ParsedPostProcessExpr, llmfn_post_process.LLMFnPostProcessBatchAddFn ): """An expression that returns the value of a new column to add to a row.""" def __init__(self, name: str, fn: Callable[[str], Any]): """Constructor. Args: name: The name of the expression. The name of the new column will be derived from this. fn: A function that takes the result of a row and returns a new value to add as a new column in the row. 
""" self._name = name self._fn = fn def name(self) -> str: return self._name def __call__(self, rows: Sequence[llmfn_output_row.LLMFnOutputRowView]) -> Sequence[Any]: return [self._fn(row.result_value()) for row in rows] def add_to_llm_function(self, llm_fn: llm_function.LLMFunction) -> llm_function.LLMFunction: return llm_fn.add_post_process_add_fn(name=self._name, fn=self) class _ParsedPostProcessReplaceExpr( ParsedPostProcessExpr, llmfn_post_process.LLMFnPostProcessBatchReplaceFn ): """An expression that returns the new result value for a row.""" def __init__(self, name: str, fn: Callable[[str], str]): """Constructor. Args: name: The name of the expression. fn: A function that takes the result of a row and returns the new result. """ self._name = name self._fn = fn def name(self) -> str: return self._name def __call__(self, rows: Sequence[llmfn_output_row.LLMFnOutputRowView]) -> Sequence[str]: return [self._fn(row.result_value()) for row in rows] def add_to_llm_function(self, llm_fn: llm_function.LLMFunction) -> llm_function.LLMFunction: return llm_fn.add_post_process_replace_fn(name=self._name, fn=self) # Decorator functions. 
def post_process_add_fn(fn: Callable[[str], Any]) -> _ParsedPostProcessAddExpr:
    """Wraps `fn` as an "add" post-processing expression.

    Args:
      fn: A function that takes the result of a row and returns a value to
        add as a new column. Its `__name__` is used as the expression name.

    Returns:
      A `_ParsedPostProcessAddExpr` wrapping `fn`.
    """
    return _ParsedPostProcessAddExpr(name=fn.__name__, fn=fn)


def post_process_replace_fn(fn: Callable[[str], str]) -> _ParsedPostProcessReplaceExpr:
    """Wraps `fn` as a "replace" post-processing expression.

    Args:
      fn: A function that takes the result of a row and returns the new
        result. Its `__name__` is used as the expression name.

    Returns:
      A `_ParsedPostProcessReplaceExpr` wrapping `fn`.
    """
    return _ParsedPostProcessReplaceExpr(name=fn.__name__, fn=fn)


def validate_one_post_processing_expression(
    tokens: Sequence[str],
) -> None:
    """Checks that `tokens` forms exactly one post-processing expression.

    Args:
      tokens: The tokens making up a single post-processing expression.

    Raises:
      PostProcessParseError: If `tokens` is empty or has more than one token.
    """
    if not tokens:
        raise PostProcessParseError("Cannot have empty post-processing expression")
    if len(tokens) > 1:
        raise PostProcessParseError("Post-processing expression should be a single token")


def _resolve_one_post_processing_expression(
    tokens: Sequence[str],
) -> tuple[str, Any]:
    """Returns name and the resolved expression.

    The single token is split on "." and each part is looked up in the
    `vars()` of the previously resolved object, starting from the main module.

    Args:
      tokens: The tokens making up a single post-processing expression.

    Returns:
      A tuple of (expression name, resolved object).

    Raises:
      PostProcessParseError: If `tokens` is not a single token, or if any part
        of the dotted path cannot be resolved.
    """
    validate_one_post_processing_expression(tokens)

    token = tokens[0]
    token_parts = token.split(".")
    current = py_utils.get_main_module()
    for part_num, part in enumerate(token_parts):
        try:
            namespace = vars(current)
        except TypeError:
            # vars() raises TypeError for objects without a __dict__ (e.g. an
            # int bound in the main module). Treat that as "nothing to look up"
            # so the caller sees a parse error rather than a raw TypeError.
            namespace = {}
        if part not in namespace:
            raise PostProcessParseError(
                'Unable to resolve "{}"'.format(".".join(token_parts[: part_num + 1]))
            )
        current = namespace[part]

    # validate_one_post_processing_expression guarantees exactly one token, so
    # the expression name is that token.
    return (token, current)


def resolve_post_processing_tokens(
    tokens: Sequence[Sequence[str]],
) -> Sequence[ParsedPostProcessExpr]:
    """Resolves post-processing tokens into ParsedPostProcessExprs.

    E.g. Given [["add_length"], ["to_upper"]] as input, this function will
    return a sequence of ParsedPostProcessExprs that will execute add_length()
    and to_upper() on each entry of the LLM output as post-processing
    operations.

    Args:
      tokens: A sequence of post-processing tokens after splitting.

    Returns:
      A sequence of ParsedPostProcessExprs.

    Raises:
      PostProcessParseError: An error parsing or resolving the tokens.
    """
    results: list[ParsedPostProcessExpr] = []
    for expression in tokens:
        expr_name, expr_value = _resolve_one_post_processing_expression(expression)
        if isinstance(expr_value, ParsedPostProcessExpr):
            # Already decorated: use the expression as-is.
            results.append(expr_value)
        elif callable(expr_value):
            # By default, assume that an undecorated function is an "add" function.
            results.append(_ParsedPostProcessAddExpr(name=expr_name, fn=expr_value))
        else:
            raise PostProcessParseError("{} is not callable".format(expr_name))
    return results