backend-apis/deployment_scripts/vertex_vector_generate_embeddings.py (107 lines of code) (raw):
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import base64
import json
import os
import typing
import numpy as np
import requests
from google.cloud import aiplatform, storage
from google.protobuf import struct_pb2
class EmbeddingResponse(typing.NamedTuple):
    """Pair of embeddings produced by one prediction request.

    Either field is ``None`` when the matching input (text or image) was not
    supplied to ``EmbeddingPredictionClient.get_embedding``.
    """

    # Embedding of the text input, or None when no text was sent.
    text_embedding: typing.Optional[typing.Sequence[float]]
    # Embedding of the image input, or None when no image was sent.
    image_embedding: typing.Optional[typing.Sequence[float]]
class EmbeddingPredictionClient:
"""Wrapper around Prediction Service Client."""
def __init__(
self,
project: str,
location: str = "us-central1",
api_regional_endpoint: str = "us-central1-aiplatform.googleapis.com",
):
client_options = {"api_endpoint": api_regional_endpoint}
# Initialize client that will be used to create and send requests.
# This client only needs to be created once, and can be reused for multiple requests.
self.client = aiplatform.gapic.PredictionServiceClient(
client_options=client_options
)
self.location = location
self.project = project
def get_embedding(
self, text: str | None = None, image_bytes: bytes | None = None
):
if not text and not image_bytes:
raise ValueError(
"At least one of text or image_bytes must be specified."
)
instance = struct_pb2.Struct()
if text:
instance.fields["text"].string_value = text
if image_bytes:
encoded_content = base64.b64encode(image_bytes).decode("utf-8")
image_struct = instance.fields["image"].struct_value
image_struct.fields[
"bytesBase64Encoded"
].string_value = encoded_content
instances = [instance]
endpoint = (
f"projects/{self.project}/locations/{self.location}"
"/publishers/google/models/multimodalembedding@001"
)
response = self.client.predict(endpoint=endpoint, instances=instances)
text_embedding = None
if text:
text_emb_value = response.predictions[0]["textEmbedding"]
text_embedding = [v for v in text_emb_value]
image_embedding = None
if image_bytes:
image_emb_value = response.predictions[0]["imageEmbedding"]
image_embedding = [v for v in image_emb_value]
return EmbeddingResponse(
text_embedding=text_embedding, image_embedding=image_embedding
)
def reduce_embedding_dimension(
    vector_text=None,
    vector_image=None,
) -> list:
    """Merge a text embedding and an image embedding into a single vector.

    Despite the name, the output has the same length as each input: when
    both vectors are present they are added element-wise; when only one is
    present it is returned unchanged.

    Args:
        vector_text: Text embedding as a sequence of numbers, or ``None`` /
            empty when absent.
        vector_image: Image embedding as a sequence of numbers, or ``None`` /
            empty when absent.

    Returns:
        A list of plain Python numbers. Empty when neither vector is given.
    """
    if vector_image and vector_text:
        # Stack as two rows and sum down the columns (element-wise add).
        combined = np.sum(np.array([vector_text, vector_image]), axis=0)
    else:
        # Fall back to whichever vector exists; `or []` keeps the
        # both-missing case from building a 0-d array that cannot be listed.
        combined = np.array(vector_text or vector_image or [])
    # tolist() yields built-in numbers rather than NumPy scalars.
    return combined.tolist()
def image_text_to_embedding(
    text: str, image_uri: str, timeout: float = 60.0
) -> list:
    """Download an image and embed it together with its text.

    Relies on the module-level ``embeddings_client`` that the script's
    ``__main__`` section creates before calling into this function.

    Args:
        text: Text to embed; only the first 1020 characters are sent
            (presumably to stay within a model input limit; confirm).
        image_uri: HTTP(S) URL the image bytes are fetched from.
        timeout: Seconds to wait for the image download before giving up.

    Returns:
        The combined text/image vector from ``reduce_embedding_dimension``.

    Raises:
        requests.RequestException: If the download times out, fails, or
            returns an HTTP error status.
    """
    http_response = requests.get(image_uri, timeout=timeout)
    # Fail here rather than sending an error page body as "image" bytes.
    http_response.raise_for_status()

    response = embeddings_client.get_embedding(
        text=text[:1020], image_bytes=http_response.content
    )
    return reduce_embedding_dimension(
        vector_image=response.image_embedding,
        vector_text=response.text_embedding,
    )
def generate_metadata_upsert(input_dir: str, output_dir: str) -> None:
    """Embed every product in the input JSONL and write the datapoints file.

    Reads ``images_title_description.jsonl`` from ``input_dir`` (one JSON
    object per line with ``id``, ``title``, ``description`` and ``uri``
    keys), computes one feature vector per product, and writes
    ``vector_metadata.json`` into ``output_dir`` as a single JSON object of
    the form ``{"datapoints": [{"datapoint_id": ..., "feature_vector": ...}]}``.

    Args:
        input_dir: Directory containing ``images_title_description.jsonl``.
        output_dir: Existing directory that receives ``vector_metadata.json``.
    """
    input_path = os.path.join(input_dir, "images_title_description.jsonl")
    with open(input_path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline) that json.loads rejects.
        products = [json.loads(line) for line in f if line.strip()]

    metadata = {"datapoints": []}
    for product in products:
        # The id "0" is remapped to "1000"; the reason is not visible here
        # (presumably "0" is not usable as a datapoint id; confirm).
        datapoint_id = product["id"] if product["id"] != "0" else "1000"
        print(datapoint_id)  # progress output, one id per product
        feature_vector = image_text_to_embedding(
            text=product["title"] + " " + product["description"],
            image_uri=product["uri"],
        )
        metadata["datapoints"].append(
            {"datapoint_id": datapoint_id, "feature_vector": feature_vector}
        )

    output_path = os.path.join(output_dir, "vector_metadata.json")
    # Overwrite instead of append: appending a second JSON object to an
    # existing file would leave it unparseable as JSON.
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f)
if __name__ == "__main__":
    # Command line: three required positionals, in this order.
    parser = argparse.ArgumentParser()
    for positional_name in ("project_name", "input_dir", "output_dir"):
        parser.add_argument(positional_name)
    args = parser.parse_args()

    # Must be a module-level name: image_text_to_embedding reads it as a
    # global rather than receiving it as an argument.
    embeddings_client = EmbeddingPredictionClient(project=args.project_name)
    # NOTE(review): storage_client is not referenced by any other code
    # visible in this file; kept so the script's behavior is unchanged.
    storage_client = storage.Client()

    directories = {"input_dir": args.input_dir, "output_dir": args.output_dir}
    generate_metadata_upsert(**directories)