backend/matching-engine/services/multimodal_text_to_image_match_service.py (223 lines of code) (raw):
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from typing import Dict, List, Optional, TypeVar
import google.auth
import google.auth.transport.requests
import redis
import requests
from google.cloud.aiplatform.matching_engine import matching_engine_index_endpoint
from services.multimodal_embedding_client import MultimodalEmbeddingPredictionClient
import storage_helper
import tracer_helper
from services.match_service import (
CodeInfo,
Item,
MatchResult,
VertexAIMatchingEngineMatchService,
)
# Module-level tracer; the service methods below are wrapped in spans through
# its start_as_current_span decorator.
tracer = tracer_helper.get_tracer(__name__)
# Blob name handed to storage_helper.upload_blob when a local query image is
# uploaded before being embedded.
DESTINATION_BLOB_NAME = "multimodal_text_to_image"
def get_access_token() -> str:
    """Return an access token for the default Google credentials.

    Returns:
        The token string held by the refreshed credentials.

    Raises:
        RuntimeError: If the credentials carry no token after the refresh.
    """
    credentials, _project = google.auth.default()
    # Default credentials start out with valid == False and token == None,
    # so they must be refreshed before the token can be read.
    credentials.refresh(google.auth.transport.requests.Request())
    token = credentials.token
    if not token:
        raise RuntimeError("No access token found")
    return token
# Type parameter forwarded to the generic VertexAIMatchingEngineMatchService
# base class; the concrete services in this file bind it to Dict[str, str]
# and str respectively.
T = TypeVar("T")
class MultimodalTextToImageMatchService(VertexAIMatchingEngineMatchService[T]):
    """Match service that embeds text or image queries with the multimodal client.

    The service holds a Matching Engine index endpoint, a
    MultimodalEmbeddingPredictionClient used to turn queries into embeddings,
    and optional lists of prompt suggestions loaded from files.
    """

    @property
    def id(self) -> str:
        """Identifier of this service."""
        return self._id

    @property
    def name(self) -> str:
        """Name for this service that is shown on the frontend."""
        return self._name

    @property
    def description(self) -> str:
        """Description for this service that is shown on the frontend."""
        return self._description

    @property
    def allows_text_input(self) -> bool:
        """If true, this service allows text input."""
        return self._allows_text_input

    @property
    def allows_image_input(self) -> bool:
        """If true, this service allows image input."""
        return self._allows_image_input

    @property
    def code_info(self) -> Optional[CodeInfo]:
        """Info about code used to generate index."""
        return self._code_info

    @staticmethod
    def _load_prompts(path: Optional[str]) -> List[str]:
        """Read one prompt per line from ``path``.

        Surrounding whitespace is stripped and blank lines are skipped so that
        an empty prompt is never offered as a suggestion. Returns an empty
        list when no path is given.
        """
        if not path:
            return []
        with open(path, "r", encoding="utf-8") as f:
            stripped = (line.strip() for line in f)
            return [prompt for prompt in stripped if prompt]

    def __init__(
        self,
        id: str,
        name: str,
        description: str,
        allows_text_input: bool,
        allows_image_input: bool,
        index_endpoint_name: str,
        deployed_index_id: str,
        project_id: str,
        gcs_bucket: str,
        is_public_index_endpoint: bool,
        prompts_texts_file: Optional[str] = None,
        prompt_images_file: Optional[str] = None,
        code_info: Optional[CodeInfo] = None,
    ) -> None:
        """Create the service.

        Args:
            id: Identifier of this service.
            name: Name shown on the frontend.
            description: Description shown on the frontend.
            allows_text_input: Whether text queries (and text suggestions) are enabled.
            allows_image_input: Whether image queries (and image suggestions) are enabled.
            index_endpoint_name: Name passed to MatchingEngineIndexEndpoint.
            deployed_index_id: Id of the deployed index, stored on the instance.
            project_id: Project id passed to the embedding client.
            gcs_bucket: Bucket that uploaded query images are written to.
            is_public_index_endpoint: Stored on the instance as-is.
            prompts_texts_file: Optional file with one text prompt per line.
            prompt_images_file: Optional file with one image URL per line.
            code_info: Optional info about the code used to generate the index.
        """
        self._id = id
        self._name = name
        self._description = description
        self._code_info = code_info
        self.project_id = project_id
        self._allows_text_input = allows_text_input
        self._allows_image_input = allows_image_input
        self.gcs_bucket = gcs_bucket

        self.prompt_texts = self._load_prompts(prompts_texts_file)
        self.prompt_images = self._load_prompts(prompt_images_file)

        self.index_endpoint = (
            matching_engine_index_endpoint.MatchingEngineIndexEndpoint(
                index_endpoint_name=index_endpoint_name
            )
        )
        self.deployed_index_id = deployed_index_id
        self.client = MultimodalEmbeddingPredictionClient(project_id=self.project_id)
        self.is_public_index_endpoint = is_public_index_endpoint

    @tracer.start_as_current_span("get_suggestions")
    def get_suggestions(self, num_items: int = 60) -> List[Item]:
        """Get suggestions for search queries.

        Args:
            num_items: Maximum number of suggestions to return.

        Returns:
            A random sample of at most ``num_items`` prompts, drawn from the
            text prompts (if text input is allowed) and the image prompts (if
            image input is allowed).
        """
        text_prompts = (
            [Item(id=word, text=word, image=None) for word in self.prompt_texts]
            if self.allows_text_input
            else []
        )
        image_prompts = (
            [
                Item(id=image_url, text="", image=image_url)
                for image_url in self.prompt_images
            ]
            if self.allows_image_input
            else []
        )
        prompts = text_prompts + image_prompts
        # random.sample rejects a negative sample size, so clamp at zero.
        sample_size = max(0, min(num_items, len(prompts)))
        return random.sample(prompts, sample_size)

    def encode_image_to_embeddings(self, image_uri: str) -> List[float]:
        """Return the image embedding for the image at ``image_uri``.

        Raises:
            RuntimeError: If the embedding client fails; the original error is
                attached as the cause.
        """
        try:
            return self.client.get_embedding(
                text=None, image_file=image_uri
            ).image_embedding
        except Exception as ex:
            raise RuntimeError("Error getting embedding.") from ex

    def encode_text_to_embeddings(self, text: str) -> List[float]:
        """Return the text embedding for ``text``.

        Raises:
            RuntimeError: If the embedding client fails; the original error is
                attached as the cause.
        """
        try:
            return self.client.get_embedding(text=text, image_file=None).text_embedding
        except Exception as ex:
            raise RuntimeError("Error getting embedding.") from ex

    @tracer.start_as_current_span("convert_text_to_embeddings")
    def convert_text_to_embeddings(self, target: str) -> Optional[List[float]]:
        """Convert a text query to an embedding representation."""
        return self.encode_text_to_embeddings(text=target)

    @tracer.start_as_current_span("convert_image_to_embeddings")
    def convert_image_to_embeddings(
        self, image_file_local_path: str
    ) -> Optional[List[float]]:
        """Upload a local image file and convert it to an embedding representation."""
        # NOTE(review): every upload uses the same blob name, so concurrent
        # requests may overwrite each other's image - confirm this is intended.
        image_uri = storage_helper.upload_blob(
            source_file_name=image_file_local_path,
            bucket_name=self.gcs_bucket,
            destination_blob_name=DESTINATION_BLOB_NAME,
        )
        # Convert GCS path to HTTP path by dropping the leading scheme.
        # Assumes upload_blob returns a "gs://bucket/blob" URI - TODO confirm.
        gcs_object_path = image_uri[len("gs://"):]
        image_uri_http = f"https://storage.googleapis.com/{gcs_object_path}"
        return self.encode_image_to_embeddings(image_uri=image_uri_http)

    @tracer.start_as_current_span("convert_image_to_embeddings_remote")
    def convert_image_to_embeddings_remote(
        self, image_file_remote_path: str
    ) -> Optional[List[float]]:
        """Convert an image at a remote path to an embedding representation."""
        return self.encode_image_to_embeddings(
            image_uri=image_file_remote_path,
        )
class MercariTextToImageMatchService(MultimodalTextToImageMatchService[Dict[str, str]]):
    """Text-to-image match service whose item data is stored as Redis hashes."""

    def __init__(
        self,
        id: str,
        name: str,
        description: str,
        allows_text_input: bool,
        allows_image_input: bool,
        index_endpoint_name: str,
        deployed_index_id: str,
        project_id: str,
        redis_host: str,  # Redis host to get data about a match id
        redis_port: int,  # Redis port to get data about a match id
        gcs_bucket: str,
        is_public_index_endpoint: bool,
        prompts_texts_file: Optional[str] = None,
        prompt_images_file: Optional[str] = None,
        code_info: Optional[CodeInfo] = None,
    ) -> None:
        """Create the service and connect a Redis client for item lookups.

        All arguments except ``redis_host`` and ``redis_port`` are forwarded
        unchanged to the parent constructor.
        """
        super().__init__(
            id=id,
            name=name,
            description=description,
            code_info=code_info,
            project_id=project_id,
            allows_text_input=allows_text_input,
            allows_image_input=allows_image_input,
            gcs_bucket=gcs_bucket,
            prompts_texts_file=prompts_texts_file,
            prompt_images_file=prompt_images_file,
            index_endpoint_name=index_endpoint_name,
            deployed_index_id=deployed_index_id,
            is_public_index_endpoint=is_public_index_endpoint,
        )
        self.redis_client = redis.StrictRedis(host=redis_host, port=redis_port)

    @tracer.start_as_current_span("get_by_id")
    def get_by_id(self, id: str) -> Optional[Dict[str, str]]:
        """Get an item by id.

        Returns:
            The item's hash fields decoded to strings, or None when Redis has
            no data for ``id``.
        """
        retrieved = self.redis_client.hgetall(str(id))
        # HGETALL answers with an empty mapping (never None) for a missing
        # key, so emptiness is what signals "not found".
        if not retrieved:
            return None
        # Convert the byte strings to regular strings
        return {key.decode(): value.decode() for key, value in retrieved.items()}

    @tracer.start_as_current_span("convert_match_neighbors_to_result")
    def convert_match_neighbors_to_result(
        self, matches: List[matching_engine_index_endpoint.MatchNeighbor]
    ) -> List[Optional[MatchResult]]:
        """Build one MatchResult per neighbor from the item stored under its id.

        The result list is aligned with ``matches``; a neighbor whose id has
        no stored item yields None. The result's ``distance`` field is set to
        ``max(0, 1 - match.distance)``.
        """
        items = [self.get_by_id(match.id) for match in matches]
        return [
            MatchResult(
                title=item["name"],
                description=item["description"],
                distance=max(0, 1 - match.distance),
                url=item["url"],
                image=item["img_url"],
            )
            if item is not None
            else None
            for item, match in zip(items, matches)
        ]
class RoomsTextToImageMatchService(MultimodalTextToImageMatchService[str]):
    """Text-to-image match service whose items are the match ids themselves."""

    @tracer.start_as_current_span("get_by_id")
    def get_by_id(self, id: str) -> Optional[str]:
        """Return the given id unchanged; no lookup is performed."""
        return id

    @tracer.start_as_current_span("convert_match_neighbors_to_result")
    def convert_match_neighbors_to_result(
        self, matches: List[matching_engine_index_endpoint.MatchNeighbor]
    ) -> List[Optional[MatchResult]]:
        """Build one MatchResult per neighbor, aligned with ``matches``.

        Each result carries only an image URL (the neighbor's item appended to
        a fixed storage path) and a ``distance`` of
        ``max(0, 1 - match.distance)``; a None item yields a None entry.
        """
        # Resolve every id first, then assemble the results in a second pass.
        items: List[Optional[str]] = []
        for neighbor in matches:
            items.append(self.get_by_id(neighbor.id))

        results: List[Optional[MatchResult]] = []
        for item, neighbor in zip(items, matches):
            if item is None:
                results.append(None)
                continue
            results.append(
                MatchResult(
                    title=None,
                    description=None,
                    distance=max(0, 1 - neighbor.distance),
                    url=None,
                    image=f"https://storage.googleapis.com/ai-demos-us-central1/interior_images/mit_indoor/{item}",
                )
            )
        return results