pyrit/embedding/_text_embedding.py (38 lines of code) (raw):
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import abc
from typing import Union
import tenacity
from openai import AzureOpenAI, OpenAI
from pyrit.models import (
EmbeddingData,
EmbeddingResponse,
EmbeddingSupport,
EmbeddingUsageInformation,
)
class _TextEmbedding(EmbeddingSupport, abc.ABC):
    """Text embedding base class backed by an OpenAI / Azure OpenAI client.

    Concrete subclasses must provide ``_client`` and ``_model`` (so that they
    exist on the instance by the time ``_TextEmbedding.__init__`` runs);
    otherwise construction raises ``NotImplementedError``.
    """

    # Client whose ``embeddings.create`` method is called; supplied by the subclass.
    _client: Union[OpenAI, AzureOpenAI]
    # Model name forwarded to ``embeddings.create`` on every request.
    _model: str

    def __init__(self) -> None:
        """Check that the implementing subclass configured a client and a model.

        Raises:
            NotImplementedError: If ``_client`` or ``_model`` is not set.
        """
        super().__init__()
        # The class-level annotations only declare the attributes; they do not
        # create them, so hasattr() is False until something assigns a value.
        if not (hasattr(self, "_client") and hasattr(self, "_model")):
            raise NotImplementedError(
                "Text embedding client and model need to be provided by the implementing child class."
            )

    @tenacity.retry(wait=tenacity.wait_fixed(0.1), stop=tenacity.stop_after_delay(3))
    def generate_text_embedding(self, text: str, **kwargs) -> EmbeddingResponse:
        """Generate the embedding for ``text`` using the configured client and model.

        The call is retried on failure with a fixed 0.1 s wait, stopping after
        3 s. No ``reraise`` is configured, so with tenacity's defaults an
        exhausted retry surfaces as ``tenacity.RetryError`` rather than the
        underlying client exception.

        Args:
            text: The text to generate the embedding for.
            **kwargs: Additional arguments forwarded to ``embeddings.create``.

        Returns:
            An ``EmbeddingResponse`` containing one ``EmbeddingData`` entry per
            embedding returned by the client, plus the model name, object
            type, and token usage reported by the client.
        """
        embedding_obj = self._client.embeddings.create(input=text, model=self._model, **kwargs)

        # Convert every returned item rather than only the first one, so that
        # no embedding the client produced is silently discarded.
        data = [
            EmbeddingData(
                embedding=item.embedding,
                index=item.index,
                object=item.object,
            )
            for item in embedding_obj.data
        ]

        return EmbeddingResponse(
            model=embedding_obj.model,
            object=embedding_obj.object,
            data=data,
            usage=EmbeddingUsageInformation(
                prompt_tokens=embedding_obj.usage.prompt_tokens,
                total_tokens=embedding_obj.usage.total_tokens,
            ),
        )