google/generativeai/notebook/lib/llmfn_input_utils.py (37 lines of code) (raw):

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for handling input variables."""
from __future__ import annotations

from typing import Any, Mapping, Sequence, Union

from google.generativeai.notebook.lib import llmfn_inputs_source

_NormalizedInputsList = llmfn_inputs_source.NormalizedInputsList
_ColumnOrderValuesList = Mapping[str, Sequence[str]]

LLMFunctionInputs = Union[_ColumnOrderValuesList, llmfn_inputs_source.LLMFnInputsSource]


def _is_column_order_values_list(inputs: Any) -> bool:
    """Check whether `inputs` looks like {"key1": ["val1", "val2", ...]}.

    This is the column-oriented layout, similar to what
    pandas.DataFrame.to_dict(orient="list") produces.

    Only the container shapes are inspected: `inputs` must be a Mapping and
    each of its values must be a non-string, non-bytes Sequence. Keys and the
    elements inside each Sequence are not checked.

    Args:
      inputs: The inputs passed into an LLMFunction.

    Returns:
      True if `inputs` is a Mapping whose values are all Sequences other than
      str/bytes (an empty Mapping also qualifies); False otherwise.
    """
    if not isinstance(inputs, Mapping):
        return False

    # str and bytes are technically Sequences, but iterating them yields
    # single characters rather than whole values, so they are rejected.
    return all(
        isinstance(column, Sequence) and not isinstance(column, (str, bytes))
        for column in inputs.values()
    )


# TODO(b/273688393): Perform stricter validation on `inputs`.
def _normalize_column_order_values_list(
    inputs: _ColumnOrderValuesList,
) -> _NormalizedInputsList:
    """Transforms column-ordered prompt inputs into a list of row dictionaries.

    For example, {"a": ["1", "2"], "b": ["x", "y"]} becomes
    [{"a": "1", "b": "x"}, {"a": "2", "b": "y"}].

    Args:
      inputs: A mapping from column name to the sequence of values for that
        column. Every column must contain the same number of values.

    Returns:
      A new list with one dictionary per row, each mapping every column name
      to that row's value. Empty if `inputs` has no columns or no rows.

    Raises:
      ValueError: If the columns do not all have the same length.
    """
    columns = list(inputs.items())
    if not columns:
        return []

    # The first column fixes the row count; every other column must agree.
    # Without this check a shorter column would fail with an opaque IndexError
    # and a longer column would have its extra values silently dropped.
    first_key, first_values = columns[0]
    num_rows = len(first_values)
    for key, values in columns[1:]:
        if len(values) != num_rows:
            raise ValueError(
                "All input columns must have the same number of values: "
                f"{first_key!r} has {num_rows} but {key!r} has {len(values)}"
            )

    return [{key: values[row_num] for key, values in columns} for row_num in range(num_rows)]


def to_normalized_inputs(inputs: LLMFunctionInputs) -> _NormalizedInputsList:
    """Handles the different types of `inputs` and returns a normalized form.

    Args:
      inputs: Either an LLMFnInputsSource or a column-ordered mapping of the
        form {"key1": ["val1", "val2", ...]}.

    Returns:
      A new list of row mappings.

    Raises:
      ValueError: If `inputs` is of neither supported form, or if it is a
        column-ordered mapping whose columns differ in length.
    """
    if isinstance(inputs, llmfn_inputs_source.LLMFnInputsSource):
        # Copy into a fresh list so the caller never aliases whatever
        # container the source hands back.
        return list(inputs.to_normalized_inputs())

    if _is_column_order_values_list(inputs):
        return _normalize_column_order_values_list(inputs)

    raise ValueError(f"Unsupported input type {inputs!r}")