backend/matching-engine/services/multimodal_embedding_client.py (61 lines of code) (raw):

"""Client wrapper for the Vertex AI multimodal embedding model
(publishers/google/models/multimodalembedding@001)."""

import base64
from typing import NamedTuple, Optional, Sequence

import requests
from google.cloud import aiplatform
from google.protobuf import struct_pb2


class EmbeddingResponse(NamedTuple):
    """Embeddings returned by the multimodal model.

    A field is None when the corresponding modality was not requested.
    """

    text_embedding: Optional[Sequence[float]]
    image_embedding: Optional[Sequence[float]]


def load_image_bytes(image_uri: str) -> bytes:
    """Load image bytes from a remote (http/https) or local URI.

    Args:
        image_uri: An http(s) URL or a local filesystem path.

    Returns:
        The raw image bytes.

    Raises:
        requests.HTTPError: if a remote fetch returns an error status.
        OSError: if a local file cannot be read.
    """
    if image_uri.startswith(("http://", "https://")):
        # BUG FIX: the original left `image_bytes` unbound on a non-200
        # response (UnboundLocalError); fail loudly instead. The timeout
        # prevents an indefinite hang on an unresponsive server.
        response = requests.get(image_uri, stream=True, timeout=60)
        response.raise_for_status()
        return response.content
    # Context manager guarantees the file handle is closed (original leaked it).
    with open(image_uri, "rb") as f:
        return f.read()


class MultimodalEmbeddingPredictionClient:
    """Wrapper around Prediction Service Client."""

    def __init__(
        self,
        project_id: str,
        location: str = "us-central1",
        api_regional_endpoint: str = "us-central1-aiplatform.googleapis.com",
    ):
        """Create a prediction client bound to a regional endpoint.

        The underlying gapic client only needs to be created once and can be
        reused for multiple requests.
        """
        client_options = {"api_endpoint": api_regional_endpoint}
        self.client = aiplatform.gapic.PredictionServiceClient(
            client_options=client_options
        )
        self.location = location
        self.project_id = project_id

    def get_embedding(
        self, text: Optional[str] = None, image_file: Optional[str] = None
    ) -> EmbeddingResponse:
        """Return text and/or image embeddings for the given inputs.

        Args:
            text: Optional text to embed.
            image_file: Optional local path or http(s) URI of an image.

        Returns:
            EmbeddingResponse with one entry per requested modality.

        Raises:
            ValueError: if neither text nor image_file is provided.
        """
        if not text and not image_file:
            raise ValueError("At least one of text or image_file must be specified.")

        image_bytes = load_image_bytes(image_file) if image_file else None

        # Build the prediction instance as a protobuf Struct: the model
        # expects "text" and/or "image.bytesBase64Encoded" fields.
        instance = struct_pb2.Struct()
        if text:
            instance.fields["text"].string_value = text
        if image_bytes:
            encoded_content = base64.b64encode(image_bytes).decode("utf-8")
            image_struct = instance.fields["image"].struct_value
            image_struct.fields["bytesBase64Encoded"].string_value = encoded_content

        endpoint = (
            f"projects/{self.project_id}/locations/{self.location}"
            "/publishers/google/models/multimodalembedding@001"
        )
        response = self.client.predict(endpoint=endpoint, instances=[instance])

        # Only extract the embeddings for the modalities that were requested.
        prediction = response.predictions[0]
        text_embedding = list(prediction["textEmbedding"]) if text else None
        image_embedding = (
            list(prediction["imageEmbedding"]) if image_bytes else None
        )
        return EmbeddingResponse(
            text_embedding=text_embedding, image_embedding=image_embedding
        )