# google/generativeai/answer.py
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import dataclasses
from collections.abc import Iterable
import itertools
from typing import Any, Union, Mapping, Optional

from typing_extensions import TypedDict

import google.ai.generativelanguage as glm

from google.generativeai import protos
from google.generativeai.client import (
    get_default_generative_client,
    get_default_generative_async_client,
)
from google.generativeai.types import model_types
from google.generativeai.types import helper_types
from google.generativeai.types import safety_types
from google.generativeai.types import content_types
from google.generativeai.types import retriever_types
from google.generativeai.types.retriever_types import MetadataFilter

DEFAULT_ANSWER_MODEL = "models/aqa"

AnswerStyle = protos.GenerateAnswerRequest.AnswerStyle

AnswerStyleOptions = Union[int, str, AnswerStyle]

_ANSWER_STYLES: dict[AnswerStyleOptions, AnswerStyle] = {
    AnswerStyle.ANSWER_STYLE_UNSPECIFIED: AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
    0: AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
    "answer_style_unspecified": AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
    "unspecified": AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
    AnswerStyle.ABSTRACTIVE: AnswerStyle.ABSTRACTIVE,
    1: AnswerStyle.ABSTRACTIVE,
    "answer_style_abstractive": AnswerStyle.ABSTRACTIVE,
    "abstractive": AnswerStyle.ABSTRACTIVE,
    AnswerStyle.EXTRACTIVE: AnswerStyle.EXTRACTIVE,
    2: AnswerStyle.EXTRACTIVE,
    "answer_style_extractive": AnswerStyle.EXTRACTIVE,
    "extractive": AnswerStyle.EXTRACTIVE,
    AnswerStyle.VERBOSE: AnswerStyle.VERBOSE,
    3: AnswerStyle.VERBOSE,
    "answer_style_verbose": AnswerStyle.VERBOSE,
    "verbose": AnswerStyle.VERBOSE,
}


def to_answer_style(x: AnswerStyleOptions) -> AnswerStyle:
    """Coerces an int, string, or `AnswerStyle` value into an `AnswerStyle` enum member."""
    if isinstance(x, str):
        x = x.lower()
    return _ANSWER_STYLES[x]
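
# A minimal sketch of the accepted spellings (illustrative only, not executed
# at import time):
#
#     to_answer_style("verbose")                # -> AnswerStyle.VERBOSE
#     to_answer_style("ANSWER_STYLE_VERBOSE")   # string aliases are case-insensitive
#     to_answer_style(3)                        # -> AnswerStyle.VERBOSE
#     to_answer_style(AnswerStyle.VERBOSE)      # enum members pass through unchanged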


GroundingPassageOptions = Union[
    protos.GroundingPassage, tuple[str, content_types.ContentType], content_types.ContentType
]

GroundingPassagesOptions = Union[
    protos.GroundingPassages,
    Iterable[GroundingPassageOptions],
    Mapping[str, content_types.ContentType],
]


def _make_grounding_passages(source: GroundingPassagesOptions) -> protos.GroundingPassages:
    """
    Converts `source` into a `protos.GroundingPassages`. A `GroundingPassages` contains a list of
    `protos.GroundingPassage` objects, each of which pairs a `protos.Content` with a string `id`.

    Args:
        source: `Content` or a `GroundingPassagesOptions` that will be converted to a
            `protos.GroundingPassages`.

    Returns:
        A `protos.GroundingPassages` to be passed into a `protos.GenerateAnswerRequest`.
    """
    if isinstance(source, protos.GroundingPassages):
        return source

    if not isinstance(source, Iterable):
        raise TypeError(
            f"Invalid input: The 'source' argument must be an instance of 'GroundingPassagesOptions'. "
            f"Received a '{type(source).__name__}' object instead."
        )

    passages = []
    if isinstance(source, Mapping):
        source = source.items()

    for n, data in enumerate(source):
        if isinstance(data, protos.GroundingPassage):
            passages.append(data)
        elif isinstance(data, tuple):
            id, content = data  # The tuple must have exactly 2 items: (id, content).
            passages.append({"id": id, "content": content_types.to_content(content)})
        else:
            passages.append({"id": str(n), "content": content_types.to_content(data)})

    return protos.GroundingPassages(passages=passages)
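
# A minimal sketch of the input shapes this helper accepts (placeholder text,
# illustrative only):
#
#     _make_grounding_passages(["The sky is blue.", "The grass is green."])
#     # -> ids "0" and "1" are assigned from the enumeration order.
#
#     _make_grounding_passages({"sky": "The sky is blue.", "grass": "The grass is green."})
#     # -> a Mapping is treated as (id, content) pairs via `.items()`.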


SourceNameType = Union[
    str, retriever_types.Corpus, protos.Corpus, retriever_types.Document, protos.Document
]


class SemanticRetrieverConfigDict(TypedDict):
    source: SourceNameType
    query: content_types.ContentsType
    metadata_filter: Optional[Iterable[MetadataFilter]]
    max_chunks_count: Optional[int]
    minimum_relevance_score: Optional[float]


SemanticRetrieverConfigOptions = Union[
    SourceNameType,
    SemanticRetrieverConfigDict,
    protos.SemanticRetrieverConfig,
]


def _maybe_get_source_name(source) -> str | None:
    """Returns the resource name for a string, `Corpus`, or `Document` source, else `None`."""
    if isinstance(source, str):
        return source
    elif isinstance(
        source, (retriever_types.Corpus, protos.Corpus, retriever_types.Document, protos.Document)
    ):
        return source.name
    else:
        return None


def _make_semantic_retriever_config(
    source: SemanticRetrieverConfigOptions,
    query: content_types.ContentsType,
) -> protos.SemanticRetrieverConfig:
    if isinstance(source, protos.SemanticRetrieverConfig):
        return source

    name = _maybe_get_source_name(source)
    if name is not None:
        source = {"source": name}
    elif isinstance(source, dict):
        source["source"] = _maybe_get_source_name(source["source"])
    else:
        raise TypeError(
            f"Invalid input: Failed to create a 'protos.SemanticRetrieverConfig' from the provided source. "
            f"Received type: {type(source).__name__}, "
            f"Received value: {source}"
        )

    # Use .get() so a bare {"source": name} dict (which has no "query" key yet)
    # falls back to the query taken from the conversation contents.
    if source.get("query") is None:
        source["query"] = query
    elif isinstance(source["query"], str):
        source["query"] = content_types.to_content(source["query"])

    return protos.SemanticRetrieverConfig(source)
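
# A minimal sketch (hypothetical corpus name, illustrative only): a bare
# resource name, Corpus, or Document collapses to {"source": name}, and the
# query defaults to the last turn of the conversation:
#
#     config = _make_semantic_retriever_config("corpora/my-corpus", contents[-1])
#     # -> protos.SemanticRetrieverConfig(source="corpora/my-corpus", query=<last turn>)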


def _make_generate_answer_request(
    *,
    model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL,
    contents: content_types.ContentsType,
    inline_passages: GroundingPassagesOptions | None = None,
    semantic_retriever: SemanticRetrieverConfigOptions | None = None,
    answer_style: AnswerStyle | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    temperature: float | None = None,
) -> protos.GenerateAnswerRequest:
    """
    Constructs a `protos.GenerateAnswerRequest` by normalizing the input parameters for the API
    call that generates a grounded answer from the model.

    Args:
        model: Name of the model used to generate the grounded response.
        contents: Content of the current conversation with the model. For a single-turn query,
            this is a single question to answer. For multi-turn queries, this is a repeated field
            that contains the conversation history, with the last `Content` in the list containing
            the question.
        inline_passages: Grounding passages (a list of `Content`-like objects or `(id, content)`
            pairs, or a `protos.GroundingPassages`) to send inline with the request. Mutually
            exclusive with `semantic_retriever`: exactly one of the two must be set.
        semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for
            grounding. Mutually exclusive with `inline_passages`: exactly one of the two must
            be set.
        answer_style: Style for grounded answers.
        safety_settings: Safety settings for generated output.
        temperature: The temperature for randomness in the output.

    Returns:
        The constructed `protos.GenerateAnswerRequest`.
    """
    model = model_types.make_model_name(model)
    contents = content_types.to_contents(contents)

    if safety_settings:
        safety_settings = safety_types.normalize_safety_settings(safety_settings)

    if inline_passages is not None and semantic_retriever is not None:
        raise ValueError(
            f"Invalid configuration: Please set either 'inline_passages' or 'semantic_retriever', but not both. "
            f"Received for inline_passages: {inline_passages}, and for semantic_retriever: {semantic_retriever}."
        )
    elif inline_passages is not None:
        inline_passages = _make_grounding_passages(inline_passages)
    elif semantic_retriever is not None:
        semantic_retriever = _make_semantic_retriever_config(semantic_retriever, contents[-1])
    else:
        raise TypeError(
            f"Invalid configuration: Either 'inline_passages' or 'semantic_retriever' must be provided, but currently both are 'None'. "
            f"Received for inline_passages: {inline_passages}, and for semantic_retriever: {semantic_retriever}."
        )

    if answer_style:
        answer_style = to_answer_style(answer_style)

    return protos.GenerateAnswerRequest(
        model=model,
        contents=contents,
        inline_passages=inline_passages,
        semantic_retriever=semantic_retriever,
        safety_settings=safety_settings,
        temperature=temperature,
        answer_style=answer_style,
    )
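
# A minimal sketch of a constructed request (hypothetical placeholder values,
# illustrative only):
#
#     request = _make_generate_answer_request(
#         contents="Why is the sky blue?",
#         inline_passages=["Rayleigh scattering disperses short wavelengths of sunlight."],
#         answer_style="verbose",
#     )
#     # -> a protos.GenerateAnswerRequest with model defaulted to "models/aqa".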


def generate_answer(
    *,
    model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL,
    contents: content_types.ContentsType,
    inline_passages: GroundingPassagesOptions | None = None,
    semantic_retriever: SemanticRetrieverConfigOptions | None = None,
    answer_style: AnswerStyle | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    temperature: float | None = None,
    client: glm.GenerativeServiceClient | None = None,
    request_options: helper_types.RequestOptionsType | None = None,
):
    """Calls the GenerateAnswer API and returns a `types.Answer` containing the response.

    You can pass a literal list of text chunks:

    >>> from google.generativeai import answer
    >>> answer.generate_answer(
    ...     contents=question,
    ...     inline_passages=splitter.split(document)
    ... )

    Or pass a reference to a retriever Document or Corpus:

    >>> from google.generativeai import answer
    >>> from google.generativeai import retriever
    >>> my_corpus = retriever.get_corpus('my_corpus')
    >>> answer.generate_answer(
    ...     contents=question,
    ...     semantic_retriever=my_corpus
    ... )

    Args:
        model: Which model to call, as a string or a `types.Model`.
        contents: The question to be answered by the model, grounded in the
            provided source.
        inline_passages: Grounding passages (a list of `Content`-like objects or `(id, content)`
            pairs, or a `protos.GroundingPassages`) to send inline with the request. Mutually
            exclusive with `semantic_retriever`: exactly one of the two must be set.
        semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for
            grounding. Mutually exclusive with `inline_passages`: exactly one of the two must
            be set.
        answer_style: Style in which the grounded answer should be returned.
        safety_settings: Safety settings for generated output. Defaults to None.
        temperature: Controls the randomness of the output.
        client: If you're not relying on a default client, pass a
            `glm.GenerativeServiceClient` here instead.
        request_options: Options for the request.

    Returns:
        A `types.Answer` containing the model's text answer response.
    """
    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_generative_client()

    request = _make_generate_answer_request(
        model=model,
        contents=contents,
        inline_passages=inline_passages,
        semantic_retriever=semantic_retriever,
        safety_settings=safety_settings,
        temperature=temperature,
        answer_style=answer_style,
    )

    response = client.generate_answer(request, **request_options)
    return response
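
# A minimal sketch of a call with explicit options (hypothetical values,
# illustrative only; assumes the package has already been configured with an
# API key):
#
#     response = generate_answer(
#         contents="Why is the sky blue?",
#         inline_passages=["Rayleigh scattering disperses short wavelengths of sunlight."],
#         answer_style="abstractive",
#         temperature=0.2,
#         request_options={"timeout": 120},
#     )
#     # response.answer holds the grounded candidate.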


async def generate_answer_async(
    *,
    model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL,
    contents: content_types.ContentsType,
    inline_passages: GroundingPassagesOptions | None = None,
    semantic_retriever: SemanticRetrieverConfigOptions | None = None,
    answer_style: AnswerStyle | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    temperature: float | None = None,
    client: glm.GenerativeServiceAsyncClient | None = None,
    request_options: helper_types.RequestOptionsType | None = None,
):
    """
    Calls the GenerateAnswer API asynchronously and returns a `types.Answer` containing the answer.

    Args:
        model: Which model to call, as a string or a `types.Model`.
        contents: The question to be answered by the model, grounded in the
            provided source.
        inline_passages: Grounding passages (a list of `Content`-like objects or `(id, content)`
            pairs, or a `protos.GroundingPassages`) to send inline with the request. Mutually
            exclusive with `semantic_retriever`: exactly one of the two must be set.
        semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for
            grounding. Mutually exclusive with `inline_passages`: exactly one of the two must
            be set.
        answer_style: Style in which the grounded answer should be returned.
        safety_settings: Safety settings for generated output. Defaults to None.
        temperature: Controls the randomness of the output.
        client: If you're not relying on a default client, pass a
            `glm.GenerativeServiceAsyncClient` here instead.
        request_options: Options for the request.

    Returns:
        A `types.Answer` containing the model's text answer response.
    """
    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_generative_async_client()

    request = _make_generate_answer_request(
        model=model,
        contents=contents,
        inline_passages=inline_passages,
        semantic_retriever=semantic_retriever,
        safety_settings=safety_settings,
        temperature=temperature,
        answer_style=answer_style,
    )

    response = await client.generate_answer(request, **request_options)
    return response
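
# A minimal sketch of async usage (hypothetical corpus name, illustrative only;
# assumes configured credentials and an event loop driven by asyncio):
#
#     import asyncio
#
#     async def main():
#         response = await generate_answer_async(
#             contents="Why is the sky blue?",
#             semantic_retriever="corpora/my-corpus",
#         )
#
#     asyncio.run(main())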