experiments/legacy/backend/embeddings.py (68 lines of code) (raw):

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Invoke Vertex Embedding API."""

# b64encode is imported directly because the `base64` keyword argument below
# shadows the `base64` module name inside get_embedding().
from base64 import b64encode
from functools import cache
import logging
import time
from typing import NamedTuple, Optional, Sequence

from google.cloud import aiplatform
from google.protobuf import struct_pb2

import config


class EmbeddingResponse(NamedTuple):
    text_embedding: Sequence[float]
    image_embedding: Sequence[float]


class EmbeddingPredictionClient:
    """Wrapper around Prediction Service Client."""

    def __init__(self, project: str,
                 location: str = "us-central1",
                 api_regional_endpoint: str = "us-central1-aiplatform.googleapis.com"):
        client_options = {"api_endpoint": api_regional_endpoint}
        # Initialize the client that will be used to create and send requests.
        # It only needs to be created once and can be reused for multiple requests.
        self.client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)
        self.location = location
        self.project = project

    def get_embedding(self, text: Optional[str] = None, image: Optional[str] = None,
                      base64: bool = False):
        """Invoke the Vertex multimodal embedding API.

        You can pass text and/or image; if neither is passed, a ValueError is raised.

        Args:
            text: text to embed
            image: can be a local file path, a GCS URI, or a base64-encoded image
            base64: True indicates the image is base64-encoded. False (default)
                means the image is interpreted as a path (either local or GCS).

        Returns:
            Named tuple with the following attributes:
                text_embedding: 1408-dimension vector of type Sequence[float]
                image_embedding: 1408-dimension vector of type Sequence[float],
                    or None if no image was provided
        """
        if not text and not image:
            raise ValueError('At least one of text or image must be specified.')

        instance = struct_pb2.Struct()
        if text:
            if len(text) >= 1024:
                logging.warning('Text must be less than 1024 characters. Truncating text.')
                text = text[:1023]
            instance.fields['text'].string_value = text

        if image:
            image_struct = instance.fields['image'].struct_value
            if base64:
                image_struct.fields['bytesBase64Encoded'].string_value = image
            elif image.lower().startswith('gs://'):
                image_struct.fields['gcsUri'].string_value = image
            else:
                with open(image, "rb") as f:
                    image_bytes = f.read()
                encoded_content = b64encode(image_bytes).decode("utf-8")
                image_struct.fields['bytesBase64Encoded'].string_value = encoded_content

        instances = [instance]
        endpoint = (f"projects/{self.project}/locations/{self.location}"
                    "/publishers/google/models/multimodalembedding@001")
        response = self.client.predict(endpoint=endpoint, instances=instances)

        text_embedding = None
        if text:
            text_emb_value = response.predictions[0]['textEmbedding']
            text_embedding = [v for v in text_emb_value]

        image_embedding = None
        if image:
            image_emb_value = response.predictions[0]['imageEmbedding']
            image_embedding = [v for v in image_emb_value]

        return EmbeddingResponse(
            text_embedding=text_embedding,
            image_embedding=image_embedding)


@cache
def get_client(project):
    return EmbeddingPredictionClient(project)


def embed(
        text: str,
        image: Optional[str] = None,
        base64: bool = False,
        project: str = config.PROJECT) -> EmbeddingResponse:
    """Invoke the Vertex multimodal embedding API.

    Args:
        text: text to embed
        image: can be a local file path, a GCS URI, or a base64-encoded image
        base64: True indicates the image is base64-encoded. False (default)
            means the image is interpreted as a path (either local or GCS).
        project: GCP Project ID

    Returns:
        Named tuple with the following attributes:
            text_embedding: 1408-dimension vector of type Sequence[float]
            image_embedding: 1408-dimension vector of type Sequence[float],
                or None if no image was provided
    """
    client = get_client(project)
    start = time.time()
    response = client.get_embedding(text=text, image=image, base64=base64)
    end = time.time()
    # Record request latency so the timing variables above are not dead code.
    logging.debug('Embedding request took %.2fs', end - start)
    return response
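

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# Assumes config.PROJECT resolves to a GCP project with the Vertex AI API
# enabled; "gs://my-bucket/sample.jpg" is a hypothetical GCS URI.
#
# if __name__ == "__main__":
#     result = embed(text="a red bicycle leaning against a brick wall",
#                    image="gs://my-bucket/sample.jpg")
#     print(len(result.text_embedding))   # expected: 1408
#     print(len(result.image_embedding))  # expected: 1408
# ---------------------------------------------------------------------------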