# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import itertools
from typing import Any, Iterable, overload, TypeVar, Union, Mapping

import google.ai.generativelanguage as glm

from google.generativeai import protos
from google.generativeai.client import get_default_generative_client
from google.generativeai.client import get_default_generative_async_client
from google.generativeai.types import helper_types
from google.generativeai.types import model_types
from google.generativeai.types import text_types
from google.generativeai.types import content_types

DEFAULT_EMB_MODEL = "models/embedding-001"
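
# Multi-input requests are sent to the batch-embedding endpoint in chunks of
# at most this many items (see the `_batched` helper below).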
EMBEDDING_MAX_BATCH_SIZE = 100
EmbeddingTaskType = protos.TaskType
EmbeddingTaskTypeOptions = Union[int, str, EmbeddingTaskType]
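
# Accepted spellings for each embedding task type: the canonical enum member,
# its integer value, and a few string aliases (strings are matched
# case-insensitively by `to_task_type` below).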
_EMBEDDING_TASK_TYPE: dict[EmbeddingTaskTypeOptions, EmbeddingTaskType] = {
EmbeddingTaskType.TASK_TYPE_UNSPECIFIED: EmbeddingTaskType.TASK_TYPE_UNSPECIFIED,
0: EmbeddingTaskType.TASK_TYPE_UNSPECIFIED,
"task_type_unspecified": EmbeddingTaskType.TASK_TYPE_UNSPECIFIED,
"unspecified": EmbeddingTaskType.TASK_TYPE_UNSPECIFIED,
EmbeddingTaskType.RETRIEVAL_QUERY: EmbeddingTaskType.RETRIEVAL_QUERY,
1: EmbeddingTaskType.RETRIEVAL_QUERY,
"retrieval_query": EmbeddingTaskType.RETRIEVAL_QUERY,
"query": EmbeddingTaskType.RETRIEVAL_QUERY,
EmbeddingTaskType.RETRIEVAL_DOCUMENT: EmbeddingTaskType.RETRIEVAL_DOCUMENT,
2: EmbeddingTaskType.RETRIEVAL_DOCUMENT,
"retrieval_document": EmbeddingTaskType.RETRIEVAL_DOCUMENT,
"document": EmbeddingTaskType.RETRIEVAL_DOCUMENT,
EmbeddingTaskType.SEMANTIC_SIMILARITY: EmbeddingTaskType.SEMANTIC_SIMILARITY,
3: EmbeddingTaskType.SEMANTIC_SIMILARITY,
"semantic_similarity": EmbeddingTaskType.SEMANTIC_SIMILARITY,
"similarity": EmbeddingTaskType.SEMANTIC_SIMILARITY,
EmbeddingTaskType.CLASSIFICATION: EmbeddingTaskType.CLASSIFICATION,
4: EmbeddingTaskType.CLASSIFICATION,
"classification": EmbeddingTaskType.CLASSIFICATION,
EmbeddingTaskType.CLUSTERING: EmbeddingTaskType.CLUSTERING,
5: EmbeddingTaskType.CLUSTERING,
"clustering": EmbeddingTaskType.CLUSTERING,
    EmbeddingTaskType.QUESTION_ANSWERING: EmbeddingTaskType.QUESTION_ANSWERING,
    6: EmbeddingTaskType.QUESTION_ANSWERING,
    "question_answering": EmbeddingTaskType.QUESTION_ANSWERING,
    "qa": EmbeddingTaskType.QUESTION_ANSWERING,
    EmbeddingTaskType.FACT_VERIFICATION: EmbeddingTaskType.FACT_VERIFICATION,
    7: EmbeddingTaskType.FACT_VERIFICATION,
    "fact_verification": EmbeddingTaskType.FACT_VERIFICATION,
    "verification": EmbeddingTaskType.FACT_VERIFICATION,
}


def to_task_type(x: EmbeddingTaskTypeOptions) -> EmbeddingTaskType:
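    """Normalizes an int, string, or enum task-type value to an `EmbeddingTaskType`.

    For example, `to_task_type("qa")`, `to_task_type(6)`, and
    `to_task_type(EmbeddingTaskType.QUESTION_ANSWERING)` all return
    `EmbeddingTaskType.QUESTION_ANSWERING`; unrecognized values raise `KeyError`.
    """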
if isinstance(x, str):
x = x.lower()
    return _EMBEDDING_TASK_TYPE[x]


try:
# python 3.12+
_batched = itertools.batched # type: ignore
except AttributeError:
    T = TypeVar("T")

    def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]:
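        """Yields successive lists of at most `n` items from `iterable`.

        Fallback for `itertools.batched` (Python 3.12+). The stdlib version
        yields tuples while this fallback yields lists; either works for the
        request batching below.
        """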
if n < 1:
raise ValueError(
f"Invalid input: The batch size 'n' must be a positive integer. You entered: {n}. Please enter a number greater than 0."
)
batch = []
for item in iterable:
batch.append(item)
if len(batch) == n:
yield batch
batch = []
if batch:
            yield batch


@overload
def embed_content(
model: model_types.BaseModelNameOptions,
content: content_types.ContentType,
task_type: EmbeddingTaskTypeOptions | None = None,
title: str | None = None,
output_dimensionality: int | None = None,
client: glm.GenerativeServiceClient | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> text_types.EmbeddingDict: ...


@overload
def embed_content(
model: model_types.BaseModelNameOptions,
content: Iterable[content_types.ContentType],
task_type: EmbeddingTaskTypeOptions | None = None,
title: str | None = None,
output_dimensionality: int | None = None,
client: glm.GenerativeServiceClient | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> text_types.BatchEmbeddingDict: ...


def embed_content(
model: model_types.BaseModelNameOptions,
content: content_types.ContentType | Iterable[content_types.ContentType],
task_type: EmbeddingTaskTypeOptions | None = None,
title: str | None = None,
output_dimensionality: int | None = None,
    client: glm.GenerativeServiceClient | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
"""Calls the API to create embeddings for content passed in.
Args:
model:
Which [model](https://ai.google.dev/models/gemini#embedding) to
call, as a string or a `types.Model`.
content:
Content to embed.
task_type:
Optional task type for which the embeddings will be used. Can only
be set for `models/embedding-001`.
title:
An optional title for the text. Only applicable when task_type is
`RETRIEVAL_DOCUMENT`.
        output_dimensionality:
            Optional reduced dimension for the output embedding. If set, the
            returned embedding is truncated to this many values, dropping the
            excess values from the end.
        client:
            Optional `GenerativeServiceClient` to use for the request. If not
            provided, the default client is used.
        request_options:
            Options for the request.

    Returns:
        A dictionary containing the embedding (a list of float values) for the
        input content, or, when an iterable of contents is passed, a dictionary
        whose `"embedding"` entry is a list with one embedding per input.
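
    Example (an illustrative sketch; it assumes an API key has already been
    configured, e.g. via `genai.configure(api_key=...)`):

        >>> import google.generativeai as genai
        >>> result = genai.embed_content(
        ...     model="models/embedding-001",
        ...     content="Hello world",
        ... )
        >>> isinstance(result["embedding"][0], float)  # a single list of floats
        True
        >>> batch = genai.embed_content(
        ...     model="models/embedding-001",
        ...     content=["Hello world", "Goodbye world"],
        ... )
        >>> len(batch["embedding"])  # one embedding per input
        2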
"""
model = model_types.make_model_name(model)
if request_options is None:
request_options = {}
if client is None:
client = get_default_generative_client()
    if title and (
        task_type is None
        or to_task_type(task_type) is not EmbeddingTaskType.RETRIEVAL_DOCUMENT
    ):
raise ValueError(
f"Invalid task type: When a title is specified, the task must be of a 'retrieval document' type. Received task type: {task_type} and title: {title}."
)
if output_dimensionality and output_dimensionality < 0:
raise ValueError(
f"Invalid value: `output_dimensionality` must be a non-negative integer. Received: {output_dimensionality}."
)
if task_type:
task_type = to_task_type(task_type)
if isinstance(content, Iterable) and not isinstance(content, (str, Mapping)):
result = {"embedding": []}
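        # Build one EmbedContentRequest per input, send them to the batch
        # endpoint in chunks of at most EMBEDDING_MAX_BATCH_SIZE, and collect
        # the returned vectors in order.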
requests = (
protos.EmbedContentRequest(
model=model,
content=content_types.to_content(c),
task_type=task_type,
title=title,
output_dimensionality=output_dimensionality,
)
for c in content
)
for batch in _batched(requests, EMBEDDING_MAX_BATCH_SIZE):
embedding_request = protos.BatchEmbedContentsRequest(model=model, requests=batch)
embedding_response = client.batch_embed_contents(
embedding_request,
**request_options,
)
embedding_dict = type(embedding_response).to_dict(embedding_response)
result["embedding"].extend(e["values"] for e in embedding_dict["embeddings"])
return result
else:
embedding_request = protos.EmbedContentRequest(
model=model,
content=content_types.to_content(content),
task_type=task_type,
title=title,
output_dimensionality=output_dimensionality,
)
embedding_response = client.embed_content(
embedding_request,
**request_options,
)
embedding_dict = type(embedding_response).to_dict(embedding_response)
embedding_dict["embedding"] = embedding_dict["embedding"]["values"]
        return embedding_dict


@overload
async def embed_content_async(
model: model_types.BaseModelNameOptions,
content: content_types.ContentType,
task_type: EmbeddingTaskTypeOptions | None = None,
title: str | None = None,
output_dimensionality: int | None = None,
client: glm.GenerativeServiceAsyncClient | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> text_types.EmbeddingDict: ...


@overload
async def embed_content_async(
model: model_types.BaseModelNameOptions,
content: Iterable[content_types.ContentType],
task_type: EmbeddingTaskTypeOptions | None = None,
title: str | None = None,
output_dimensionality: int | None = None,
client: glm.GenerativeServiceAsyncClient | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> text_types.BatchEmbeddingDict: ...


async def embed_content_async(
model: model_types.BaseModelNameOptions,
content: content_types.ContentType | Iterable[content_types.ContentType],
task_type: EmbeddingTaskTypeOptions | None = None,
title: str | None = None,
output_dimensionality: int | None = None,
    client: glm.GenerativeServiceAsyncClient | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
"""Calls the API to create async embeddings for content passed in."""
model = model_types.make_model_name(model)
if request_options is None:
request_options = {}
if client is None:
client = get_default_generative_async_client()
    if title and (
        task_type is None
        or to_task_type(task_type) is not EmbeddingTaskType.RETRIEVAL_DOCUMENT
    ):
raise ValueError(
f"Invalid task type: When a title is specified, the task must be of a 'retrieval document' type. Received task type: {task_type} and title: {title}."
)
if output_dimensionality and output_dimensionality < 0:
raise ValueError(
f"Invalid value: `output_dimensionality` must be a non-negative integer. Received: {output_dimensionality}."
)
if task_type:
task_type = to_task_type(task_type)
if isinstance(content, Iterable) and not isinstance(content, (str, Mapping)):
result = {"embedding": []}
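        # Same chunked batching as in `embed_content`, awaiting each batch call.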
requests = (
protos.EmbedContentRequest(
model=model,
content=content_types.to_content(c),
task_type=task_type,
title=title,
output_dimensionality=output_dimensionality,
)
for c in content
)
for batch in _batched(requests, EMBEDDING_MAX_BATCH_SIZE):
embedding_request = protos.BatchEmbedContentsRequest(model=model, requests=batch)
embedding_response = await client.batch_embed_contents(
embedding_request,
**request_options,
)
embedding_dict = type(embedding_response).to_dict(embedding_response)
result["embedding"].extend(e["values"] for e in embedding_dict["embeddings"])
return result
else:
embedding_request = protos.EmbedContentRequest(
model=model,
content=content_types.to_content(content),
task_type=task_type,
title=title,
output_dimensionality=output_dimensionality,
)
embedding_response = await client.embed_content(
embedding_request,
**request_options,
)
embedding_dict = type(embedding_response).to_dict(embedding_response)
embedding_dict["embedding"] = embedding_dict["embedding"]["values"]
return embedding_dict