google/generativeai/notebook/text_model.py (44 lines of code) (raw):
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model that uses the Text service."""
from __future__ import annotations
from google.api_core import retry
import google.generativeai as genai
from google.generativeai.types import generation_types
from google.generativeai.notebook.lib import model as model_lib
_DEFAULT_MODEL = "models/gemini-1.5-flash"
class TextModel(model_lib.AbstractModel):
    """Concrete model that uses the generate_content service."""

    def _generate_text(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float | None = None,
        candidate_count: int | None = None,
    ) -> generation_types.GenerateContentResponse:
        """Sends a single generate_content request.

        Args:
            prompt: The text prompt to send.
            model: Name of the model to use; falls back to `_DEFAULT_MODEL`
                when None or empty.
            temperature: Sampling temperature; left out of the generation
                config when None.
            candidate_count: Number of candidates to request; left out of the
                generation config when None.

        Returns:
            The response returned by `GenerativeModel.generate_content`.
        """
        # Only include values the caller actually set; unset arguments are
        # omitted from the GenerationConfig entirely.
        gen_config = {}
        if temperature is not None:
            gen_config["temperature"] = temperature
        if candidate_count is not None:
            gen_config["candidate_count"] = candidate_count

        gen_model = genai.GenerativeModel(model_name=model or _DEFAULT_MODEL)
        gc = generation_types.GenerationConfig(**gen_config)
        return gen_model.generate_content(prompt, generation_config=gc)

    def call_model(
        self,
        model_input: str,
        model_args: model_lib.ModelArguments | None = None,
    ) -> model_lib.ModelResults:
        """Runs `model_input` through the model and collects the text outputs.

        Calls that fail with a transient error (as judged by
        `retry.if_transient_error`) are retried.

        Args:
            model_input: The prompt text.
            model_args: Optional model / temperature / candidate_count
                settings; a default `ModelArguments()` is used when None.

        Returns:
            A `ModelResults` holding `model_input` and one string per response
            candidate, each the concatenation of that candidate's part texts.
            If the response has no candidates, `text_results` is empty.
        """
        if model_args is None:
            model_args = model_lib.ModelArguments()

        # Wrap the generation function here, rather than decorate, so that it
        # applies to any overridden calls too.
        retryable_fn = retry.Retry(retry.if_transient_error)(self._generate_text)
        response = retryable_fn(
            prompt=model_input,
            model=model_args.model,
            temperature=model_args.temperature,
            candidate_count=model_args.candidate_count,
        )

        text_outputs = [
            "".join(part.text for part in candidate.content.parts)
            for candidate in response.candidates
        ]
        return model_lib.ModelResults(
            model_input=model_input,
            text_results=text_outputs,
        )