backend/matching-engine/services/match_service.py (171 lines of code) (raw):
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import dataclasses
import functools
import logging
from typing import Generic, List, Optional, TypeVar
from google.cloud.aiplatform.matching_engine import (
matching_engine_index,
matching_engine_index_endpoint,
)
import tracer_helper
T = TypeVar("T")
logger = logging.getLogger(__name__)
tracer = tracer_helper.get_tracer(__name__)
@dataclasses.dataclass
class MatchResult:
    """A single match returned by a match service.

    Only ``distance`` is required; every other field defaults to ``None``.
    """

    # Distance value reported for this match.
    distance: float
    # Optional presentation fields; each one is None unless supplied.
    title: Optional[str] = dataclasses.field(default=None)
    description: Optional[str] = dataclasses.field(default=None)
    url: Optional[str] = dataclasses.field(default=None)
    image: Optional[str] = dataclasses.field(default=None)
@dataclasses.dataclass()
class Item:
    """An item made of a text plus an optional id and an optional image.

    All three fields are required constructor arguments; ``id`` and
    ``image`` accept ``None`` but have no default value.
    """

    # Text of the item.
    text: str
    # Identifier of the item, or None.
    id: Optional[str]
    # Image associated with the item, or None.
    image: Optional[str]
@dataclasses.dataclass()
class CodeInfo:
    """A URL and a title describing a piece of code."""

    # Location of the code.
    url: str
    # Title for the code.
    title: str
class MatchService(abc.ABC, Generic[T]):
    """Abstract interface for a match (similarity search) service.

    ``T`` is the type of item returned by :meth:`get_by_id`.

    Subclasses must implement every abstract member. The image-related
    methods are optional: their default implementations raise
    ``NotImplementedError``.
    """

    # NOTE: ``abc.abstractproperty`` has been deprecated since Python 3.3;
    # stacking ``@property`` on ``@abc.abstractmethod`` is the supported form.
    @property
    @abc.abstractmethod
    def id(self) -> str:
        """Unique identifier for this service."""

    @property
    @abc.abstractmethod
    def name(self) -> str:
        """Name for this service that is shown on the frontend."""

    @property
    @abc.abstractmethod
    def description(self) -> str:
        """Description for this service that is shown on the frontend."""

    @property
    def allows_text_input(self) -> bool:
        """If true, this service allows text input."""
        return False

    @property
    def allows_image_input(self) -> bool:
        """If true, this service allows image input."""
        return False

    @property
    def code_info(self) -> Optional[CodeInfo]:
        """Info about code used to generate index."""
        return None

    def convert_image_to_embeddings(
        self, image_file_local_path: str
    ) -> Optional[List[float]]:
        """Convert the image at a local path to an embedding representation.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    def convert_image_to_embeddings_remote(
        self, image_file_remote_path: str
    ) -> Optional[List[float]]:
        """Convert the image at a remote path to an embedding representation.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    def match_by_image(
        self, image_file_local_path: str, num_neighbors: int
    ) -> List[MatchResult]:
        """Find matches for the image at a local path.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    def match_by_image_remote(
        self, image_file_remote_path: str, num_neighbors: int
    ) -> List[MatchResult]:
        """Find matches for the image at a remote path.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def get_suggestions(self, num_items: int = 60) -> List[Item]:
        """Get suggestions for search queries."""

    @abc.abstractmethod
    def get_by_id(self, id: str) -> Optional[T]:
        """Get an item by id."""

    @abc.abstractmethod
    def get_total_index_count(self) -> int:
        """Get total index count."""

    @abc.abstractmethod
    def convert_text_to_embeddings(self, target: str) -> Optional[List[float]]:
        """Convert a given text to an embedding representation."""

    @abc.abstractmethod
    def convert_match_neighbors_to_result(
        self, matches: List[matching_engine_index_endpoint.MatchNeighbor]
    ) -> List[Optional[MatchResult]]:
        """Convert match neighbors to match results.

        Entries of the returned list may be ``None``.
        """

    @abc.abstractmethod
    def match_by_text(self, target: str, num_neighbors: int) -> List[MatchResult]:
        """Find matches for the given text."""
class VertexAIMatchingEngineMatchService(MatchService[T]):
    """Match service that queries a Matching Engine index endpoint.

    ``index_endpoint`` and ``deployed_index_id`` are declared here but not
    assigned; they must be set before any of the match methods is called.
    """

    index_endpoint: matching_engine_index_endpoint.MatchingEngineIndexEndpoint
    deployed_index_id: str
    # True -> query with find_neighbors(); False -> query with match().
    is_public_index_endpoint: bool = True
    # Per-instance memo for get_total_index_count(); None until computed.
    _total_index_count: Optional[int] = None

    @tracer.start_as_current_span("match_by_embeddings")
    def match_by_embeddings(
        self, embeddings: List[float], num_neighbors: int
    ) -> List[MatchResult]:
        """Query the index endpoint with an embedding vector.

        Args:
            embeddings: The query vector.
            num_neighbors: Number of neighbors requested from the endpoint.

        Returns:
            The converted matches, with ``None`` entries removed.

        Raises:
            ValueError: If ``embeddings`` is ``None``.
        """
        if embeddings is None:
            # There is no original query in scope here, so the message can
            # only describe the missing vector itself.
            raise ValueError("Embeddings must not be None")

        logger.info("len(embeddings) = %s", len(embeddings))

        if self.is_public_index_endpoint:
            response = self.index_endpoint.find_neighbors(
                deployed_index_id=self.deployed_index_id,
                queries=[embeddings],
                num_neighbors=num_neighbors,
            )
        else:
            response = self.index_endpoint.match(
                deployed_index_id=self.deployed_index_id,
                queries=[embeddings],
                num_neighbors=num_neighbors,
            )
        logger.info("index_endpoint.match completed")

        # The response holds one list of neighbors per query; flatten it.
        matches_all = self.convert_match_neighbors_to_result(
            matches=[match for matches in response for match in matches]
        )
        logger.info("matches converted")

        matches_all_nonoptional: List[MatchResult] = [
            match for match in matches_all if match is not None
        ]
        logger.info("matches none filtered")

        return matches_all_nonoptional

    @tracer.start_as_current_span("match_by_text")
    def match_by_text(self, target: str, num_neighbors: int) -> List[MatchResult]:
        """Embed ``target`` and return its nearest matches.

        Raises:
            ValueError: If no embeddings could be generated for ``target``.
        """
        logger.info(
            "match_by_text(target=%s, num_neighbors=%s)", target, num_neighbors
        )

        embeddings = self.convert_text_to_embeddings(target=target)
        if embeddings is None:
            raise ValueError(f"Embeddings could not be generated for: {target}")

        return self.match_by_embeddings(
            embeddings=embeddings, num_neighbors=num_neighbors
        )

    @tracer.start_as_current_span("match_by_image")
    def match_by_image(
        self, image_file_local_path: str, num_neighbors: int
    ) -> List[MatchResult]:
        """Embed the image at a local path and return its nearest matches.

        Raises:
            ValueError: If no embeddings could be generated for the image.
        """
        logger.info(
            "match_by_image(target=%s, num_neighbors=%s)",
            image_file_local_path,
            num_neighbors,
        )

        embeddings = self.convert_image_to_embeddings(
            image_file_local_path=image_file_local_path
        )
        if embeddings is None:
            raise ValueError(
                f"Embeddings could not be generated for: {image_file_local_path}"
            )

        return self.match_by_embeddings(
            embeddings=embeddings, num_neighbors=num_neighbors
        )

    @tracer.start_as_current_span("match_by_image_remote")
    def match_by_image_remote(
        self, image_file_remote_path: str, num_neighbors: int
    ) -> List[MatchResult]:
        """Embed the image at a remote path and return its nearest matches.

        Raises:
            ValueError: If no embeddings could be generated for the image.
        """
        logger.info(
            "match_by_image_remote(target=%s, num_neighbors=%s)",
            image_file_remote_path,
            num_neighbors,
        )

        embeddings = self.convert_image_to_embeddings_remote(
            image_file_remote_path=image_file_remote_path
        )
        if embeddings is None:
            raise ValueError(
                f"Embeddings could not be generated for: {image_file_remote_path}"
            )

        return self.match_by_embeddings(
            embeddings=embeddings, num_neighbors=num_neighbors
        )

    def get_total_index_count(self) -> int:
        """Return the summed vector count of all deployed indexes.

        The value is computed on the first call and then memoized on the
        instance. A class-level ``functools.lru_cache`` on a method would
        key on ``self`` and keep every instance alive inside the cache.
        """
        if self._total_index_count is None:
            self._total_index_count = self._fetch_total_index_count()
        return self._total_index_count

    @tracer.start_as_current_span("get_total_index_count")
    def _fetch_total_index_count(self) -> int:
        """Sum ``vectors_count`` over the endpoint's deployed indexes."""
        return sum(
            matching_engine_index.MatchingEngineIndex(
                deployed_index.index
            )._gca_resource.index_stats.vectors_count
            for deployed_index in self.index_endpoint.deployed_indexes
        )