google/generativeai/notebook/text_model.py (44 lines of code) (raw):

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model that uses the `generate_content` service."""
from __future__ import annotations

from google.api_core import retry
import google.generativeai as genai
from google.generativeai.types import generation_types
from google.generativeai.notebook.lib import model as model_lib

_DEFAULT_MODEL = "models/gemini-1.5-flash"


class TextModel(model_lib.AbstractModel):
    """Concrete model that uses the generate_content service."""

    def _generate_text(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float | None = None,
        candidate_count: int | None = None,
    ) -> generation_types.GenerateContentResponse:
        """Send a single `generate_content` request.

        Args:
            prompt: The text prompt to send.
            model: Model resource name. Falls back to `_DEFAULT_MODEL` when
                falsy (None or an empty string).
            temperature: Sampling temperature. Left out of the generation
                config when None.
            candidate_count: Number of candidates to request. Left out of the
                generation config when None.

        Returns:
            The response returned by `GenerativeModel.generate_content`.
        """
        # Only forward options the caller actually set. `is not None` (rather
        # than truthiness) keeps legitimate values such as temperature=0.0.
        gen_config = {}
        if temperature is not None:
            gen_config["temperature"] = temperature
        if candidate_count is not None:
            gen_config["candidate_count"] = candidate_count

        gen_model = genai.GenerativeModel(model_name=model or _DEFAULT_MODEL)
        config = generation_types.GenerationConfig(**gen_config)
        return gen_model.generate_content(prompt, generation_config=config)

    def call_model(
        self,
        model_input: str,
        model_args: model_lib.ModelArguments | None = None,
    ) -> model_lib.ModelResults:
        """Run `model_input` through the model, retrying transient errors.

        Args:
            model_input: The prompt text.
            model_args: Optional model/sampling arguments. A default
                `model_lib.ModelArguments()` is used when None.

        Returns:
            A `model_lib.ModelResults` holding the original input and one text
            result per response candidate. Each result is the concatenated
            text of that candidate's content parts.
        """
        if model_args is None:
            model_args = model_lib.ModelArguments()

        # Wrap the generation function here, rather than decorate, so that it
        # applies to any overridden calls too.
        retryable_fn = retry.Retry(retry.if_transient_error)(self._generate_text)
        response = retryable_fn(
            prompt=model_input,
            model=model_args.model,
            temperature=model_args.temperature,
            candidate_count=model_args.candidate_count,
        )

        text_outputs = [
            "".join(part.text for part in candidate.content.parts)
            for candidate in response.candidates
        ]
        return model_lib.ModelResults(
            model_input=model_input,
            text_results=text_outputs,
        )