backend/matching-engine/services/text_to_image_match_service.py

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from typing import List, Optional

import numpy as np
import torch
from google.cloud.aiplatform.matching_engine import matching_engine_index_endpoint
from transformers import CLIPModel, CLIPTokenizerFast

import tracer_helper
from services.match_service import (
    CodeInfo,
    Item,
    MatchResult,
    VertexAIMatchingEngineMatchService,
)

tracer = tracer_helper.get_tracer(__name__)


class TextToImageMatchService(VertexAIMatchingEngineMatchService[str]):
    @property
    def id(self) -> str:
        return self._id

    @property
    def name(self) -> str:
        """Name for this service that is shown on the frontend."""
        return self._name

    @property
    def description(self) -> str:
        """Description for this service that is shown on the frontend."""
        return self._description

    @property
    def allows_text_input(self) -> bool:
        """If true, this service allows text input."""
        return True

    @property
    def code_info(self) -> Optional[CodeInfo]:
        """Info about the code used to generate the index."""
        return self._code_info

    def __init__(
        self,
        id: str,
        name: str,
        description: str,
        prompts_file: str,
        model_id_or_path: str,
        index_endpoint_name: str,
        deployed_index_id: str,
        image_directory_uri: str,
        code_info: Optional[CodeInfo],
    ) -> None:
        self._id = id
        self._name = name
        self._description = description
        self.image_directory_uri = image_directory_uri
        self._code_info = code_info

        # Load the suggestion prompts, one per line.
        with open(prompts_file, "r") as f:
            prompts = f.readlines()
            self.prompts = [prompt.strip() for prompt in prompts]

        self.index_endpoint = (
            matching_engine_index_endpoint.MatchingEngineIndexEndpoint(
                index_endpoint_name=index_endpoint_name
            )
        )
        self.deployed_index_id = deployed_index_id

        # Use CUDA or MPS as the active device if available, otherwise fall back to CPU.
        self.device = (
            "cuda"
            if torch.cuda.is_available()
            else ("mps" if torch.backends.mps.is_available() else "cpu")
        )

        # Initialize a tokenizer and the CLIP model itself. No image processor is
        # needed here, since this service only embeds text queries.
        self.tokenizer = CLIPTokenizerFast.from_pretrained(model_id_or_path)
        # self.processor = CLIPProcessor.from_pretrained(model_id)
        self.model = CLIPModel.from_pretrained(model_id_or_path).to(self.device)

    @tracer.start_as_current_span("get_suggestions")
    def get_suggestions(self, num_items: int = 60) -> List[Item]:
        """Get suggestions for search queries."""
        return random.sample(
            [Item(id=word, text=word, image=None) for word in self.prompts],
            min(num_items, len(self.prompts)),
        )

    @tracer.start_as_current_span("get_by_id")
    def get_by_id(self, id: str) -> Optional[str]:
        """Get an item by id."""
        return f"{self.image_directory_uri}/{id}"

    @tracer.start_as_current_span("convert_text_to_embeddings")
    def convert_text_to_embeddings(self, target: str) -> Optional[List[float]]:
        # Create transformer-readable tokens.
        inputs = self.tokenizer(target, return_tensors="pt").to(self.device)

        # Use CLIP to encode the tokens into a meaningful embedding.
        text_emb = self.model.get_text_features(**inputs)
        text_emb = text_emb.cpu().detach().numpy()

        # Treat an all-zero embedding as a failure.
        if np.any(text_emb):
            return text_emb[0].tolist()
        else:
            return None
@tracer.start_as_current_span("convert_match_neighbors_to_result") def convert_match_neighbors_to_result( self, matches: List[matching_engine_index_endpoint.MatchNeighbor] ) -> List[Optional[MatchResult]]: items = [self.get_by_id(match.id) for match in matches] return [ MatchResult(title=None, distance=match.distance, image=item) if item is not None else None for item, match in zip(items, matches) ]
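
A minimal usage sketch of this service, for reference. Every concrete value below is a hypothetical placeholder: the prompt file, CLIP checkpoint id, Matching Engine endpoint resource name, deployed index id, and image bucket are assumptions, not values taken from this repository. Note that constructing the service makes a live call to Vertex AI, so valid Google Cloud credentials and an existing index endpoint are required for this to run.

# Hypothetical usage sketch -- all concrete values below are placeholders.
from services.text_to_image_match_service import TextToImageMatchService

service = TextToImageMatchService(
    id="text-to-image",
    name="Text to image",
    description="Search images with free-text prompts via CLIP embeddings.",
    prompts_file="prompts.txt",  # assumed file: one suggestion prompt per line
    model_id_or_path="openai/clip-vit-base-patch32",  # assumed CLIP checkpoint
    index_endpoint_name="projects/0000/locations/us-central1/indexEndpoints/1111",
    deployed_index_id="text_to_image_deployed_index",
    image_directory_uri="gs://example-bucket/images",
    code_info=None,
)

# Embed a free-text query; returns a list of floats (512 dimensions for the
# ViT-B/32 checkpoint assumed above), or None if the embedding is all zeros.
embedding = service.convert_text_to_embeddings("a photo of a red bicycle")

# Random prompt suggestions to seed the frontend search box.
suggestions = service.get_suggestions(num_items=5)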