# backend/matching-engine/services/multimodal_embedding_client.py
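"""Client for the Vertex AI multimodal embedding model.

Wraps the Prediction Service client to generate text and/or image embeddings.
"""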
import base64
from typing import NamedTuple, Optional, Sequence

import requests
from google.cloud import aiplatform
from google.protobuf import struct_pb2


class EmbeddingResponse(NamedTuple):
    """Text and/or image embeddings returned by the multimodal embedding model."""

    text_embedding: Optional[Sequence[float]]
    image_embedding: Optional[Sequence[float]]


def load_image_bytes(image_uri: str) -> bytes:
    """Load image bytes from a remote (http/https) URL or a local file path."""
    if image_uri.startswith("http://") or image_uri.startswith("https://"):
        response = requests.get(image_uri, stream=True)
        response.raise_for_status()
        return response.content
    with open(image_uri, "rb") as image_file:
        return image_file.read()


class MultimodalEmbeddingPredictionClient:
    """Wrapper around the Vertex AI Prediction Service client."""

def __init__(
self,
project_id: str,
location: str = "us-central1",
api_regional_endpoint: str = "us-central1-aiplatform.googleapis.com",
):
client_options = {"api_endpoint": api_regional_endpoint}
# Initialize client that will be used to create and send requests.
# This client only needs to be created once, and can be reused for multiple requests.
self.client = aiplatform.gapic.PredictionServiceClient(
client_options=client_options
)
self.location = location
self.project_id = project_id

    def get_embedding(
        self, text: Optional[str] = None, image_file: Optional[str] = None
    ) -> EmbeddingResponse:
        """Return text and/or image embeddings for the given inputs."""
        if not text and not image_file:
            raise ValueError("At least one of text or image_file must be specified.")
# Load image file
image_bytes = None
if image_file:
image_bytes = load_image_bytes(image_file)
        # Build the prediction instance as a protobuf Struct.
        instance = struct_pb2.Struct()
if text:
instance.fields["text"].string_value = text
if image_bytes:
encoded_content = base64.b64encode(image_bytes).decode("utf-8")
image_struct = instance.fields["image"].struct_value
image_struct.fields["bytesBase64Encoded"].string_value = encoded_content
instances = [instance]
endpoint = (
f"projects/{self.project_id}/locations/{self.location}"
"/publishers/google/models/multimodalembedding@001"
)
response = self.client.predict(endpoint=endpoint, instances=instances)
        # Copy the returned prediction values into plain Python lists.
        text_embedding = None
        if text:
            text_embedding = list(response.predictions[0]["textEmbedding"])
        image_embedding = None
        if image_bytes:
            image_embedding = list(response.predictions[0]["imageEmbedding"])
return EmbeddingResponse(
text_embedding=text_embedding, image_embedding=image_embedding
)
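

# Illustrative usage sketch (not part of the original module): the project ID and
# query text below are placeholder values, and calling predict requires valid
# Google Cloud credentials with access to Vertex AI.
if __name__ == "__main__":
    client = MultimodalEmbeddingPredictionClient(project_id="my-gcp-project")
    result = client.get_embedding(text="a red running shoe")
    print(f"text embedding dimensions: {len(result.text_embedding)}")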