backend/matching-engine/main.py (176 lines of code) (raw):

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""FastAPI application exposing the registered match services over HTTP.

Every endpoint resolves a service from ``match_service_registry`` by the
``match_service_id`` path parameter and delegates the actual work to it.

NOTE(review): the handlers are ``async def`` but the service methods are
called without ``await``; if those calls block, they block the event loop.
Confirm whether that is acceptable for the expected load.
"""

import dataclasses
import logging
import shutil
import tempfile
from typing import Annotated, Any, Dict, List, Optional, Tuple

from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from pydantic import BaseModel

import register_services
import tracer_helper
from services import match_service

logger = logging.getLogger(__name__)
tracer = tracer_helper.get_tracer(__name__)

app = FastAPI()

match_service_registry: Dict[
    str, match_service.MatchService
] = register_services.register_services()

# NOTE(review): wildcard origins together with allow_credentials=True is a
# very permissive CORS policy; confirm it is intended for this deployment.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(GZipMiddleware, minimum_size=1000)

# Generic detail returned to clients when a match call fails; the underlying
# exception is logged server-side and not leaked in the response.
_MATCH_ERROR_DETAIL = "There was an error getting matches"


def _get_service_or_400(match_service_id: str) -> match_service.MatchService:
    """Return the registered service for ``match_service_id``.

    Raises:
        HTTPException: 400 when no service is registered under that id.
    """
    service = match_service_registry.get(match_service_id)
    if not service:
        raise HTTPException(
            status_code=400,
            detail=f"Match service not found for id: {match_service_id}",
        )
    return service


class GetItemsResponse(BaseModel):
    """Response body of ``GET /items/{match_service_id}``."""

    items: List[match_service.Item]


@app.get("/match-registry")
async def get_match_registry():
    """Return one descriptor dict per registered match service."""
    # The span is opened inside the handler: a tracer decorator stacked above
    # ``@app.get`` wraps the function only after FastAPI has already
    # registered the unwrapped one, so the served route would not be traced.
    with tracer.start_as_current_span("/match-registry"):
        return [
            {
                "id": service.id,
                "name": service.name,
                "description": service.description,
                "allowsTextInput": service.allows_text_input,
                "allowsImageInput": service.allows_image_input,
                "code": service.code_info,
            }
            for service in match_service_registry.values()
        ]


@app.get("/items/{match_service_id}")
async def get_items(match_service_id: str):
    """Return the suggestions offered by the given match service.

    Raises:
        HTTPException: 400 when the service id is not registered.
    """
    with tracer.start_as_current_span(f"/items/{match_service_id}"):
        service = _get_service_or_400(match_service_id)
        return GetItemsResponse(items=service.get_suggestions())


class MatchByIdRequest(BaseModel):
    """Request body of ``POST /match-by-id/{match_service_id}``."""

    id: str
    numNeighbors: int = 10


class MatchByTextRequest(BaseModel):
    """Request body of ``POST /match-by-text/{match_service_id}``."""

    text: str
    numNeighbors: int = 10


@dataclasses.dataclass
class MatchResponse:
    """Response body shared by all ``/match-by-*`` endpoints."""

    totalIndexCount: int
    results: List[match_service.MatchResult]


@app.post("/match-by-id/{match_service_id}")
async def match_by_id(
    match_service_id: str, request: MatchByIdRequest
) -> MatchResponse:
    """Fetch the item with ``request.id`` and match against it.

    The item returned by the service's ``get_by_id`` is passed as the target
    of the service's ``match_by_text``.

    Raises:
        HTTPException: 400 when the service id is not registered, 404 when
            the item does not exist, 500 when the match call fails.
    """
    with tracer.start_as_current_span(f"/match-by-id/{match_service_id}"):
        service = _get_service_or_400(match_service_id)

        item = service.get_by_id(id=request.id)
        if item is None:
            raise HTTPException(
                status_code=404, detail=f"Item not found for id: {request.id}"
            )

        try:
            results = service.match_by_text(
                target=item, num_neighbors=request.numNeighbors
            )
            # Built inside the try, like the sibling handlers, so a failure
            # while counting the index yields the same 500 response.
            return MatchResponse(
                totalIndexCount=service.get_total_index_count(), results=results
            )
        except Exception as ex:
            logger.exception(
                "match_by_id failed for service %s", match_service_id
            )
            raise HTTPException(
                status_code=500, detail=_MATCH_ERROR_DETAIL
            ) from ex


@app.post("/match-by-text/{match_service_id}")
async def match_by_text(
    match_service_id: str, request: MatchByTextRequest
) -> MatchResponse:
    """Match against the free text in ``request.text``.

    Raises:
        HTTPException: 400 when the service id is not registered, 500 when
            the match call fails.
    """
    with tracer.start_as_current_span(f"/match-by-text/{match_service_id}"):
        service = _get_service_or_400(match_service_id)

        try:
            results = service.match_by_text(
                target=request.text, num_neighbors=request.numNeighbors
            )
            return MatchResponse(
                totalIndexCount=service.get_total_index_count(), results=results
            )
        except Exception as ex:
            logger.exception(
                "match_by_text failed for service %s", match_service_id
            )
            raise HTTPException(
                status_code=500, detail=_MATCH_ERROR_DETAIL
            ) from ex


@app.post("/match-by-image/{match_service_id}")
async def match_by_image(
    match_service_id: str, image: UploadFile, numNeighbors: int = 10
) -> MatchResponse:
    """Match against an uploaded image.

    The upload is copied to a temporary file whose path is handed to the
    service's ``match_by_image``.

    Raises:
        HTTPException: 400 when the service id is not registered or the
            upload has no filename, 500 when the match call fails.
    """
    with tracer.start_as_current_span(f"/match-by-image/{match_service_id}"):
        service = _get_service_or_400(match_service_id)

        if image.filename is None:
            raise HTTPException(
                status_code=400,
                detail="No image uploaded",
            )

        try:
            with tempfile.NamedTemporaryFile() as f:
                shutil.copyfileobj(image.file, f)
                # The service reopens the file by path, so buffered bytes
                # must be pushed to disk first or it may read a truncated
                # (or empty) image.
                f.flush()

                results = service.match_by_image(
                    image_file_local_path=f.name,
                    num_neighbors=numNeighbors,
                )

            return MatchResponse(
                totalIndexCount=service.get_total_index_count(), results=results
            )
        except Exception as ex:
            logger.exception(
                "match_by_image failed for service %s", match_service_id
            )
            raise HTTPException(
                status_code=500, detail=_MATCH_ERROR_DETAIL
            ) from ex
        finally:
            image.file.close()


class MatchByImageUrlRequest(BaseModel):
    """Request body of ``POST /match-by-image-url/{match_service_id}``."""

    imageUrl: str
    numNeighbors: int = 10


@app.post("/match-by-image-url/{match_service_id}")
async def match_by_image_url(
    match_service_id: str, request: MatchByImageUrlRequest
) -> MatchResponse:
    """Match against the remote image referenced by ``request.imageUrl``.

    The URL is forwarded unchanged to the service's ``match_by_image_remote``.

    Raises:
        HTTPException: 400 when the service id is not registered, 500 when
            the match call fails.
    """
    with tracer.start_as_current_span(f"/match-by-image-url/{match_service_id}"):
        service = _get_service_or_400(match_service_id)

        try:
            # Use remote image url
            results = service.match_by_image_remote(
                image_file_remote_path=request.imageUrl,
                num_neighbors=request.numNeighbors,
            )
            return MatchResponse(
                totalIndexCount=service.get_total_index_count(), results=results
            )
        except Exception as ex:
            logger.exception(
                "match_by_image_url failed for service %s", match_service_id
            )
            raise HTTPException(
                status_code=500, detail=_MATCH_ERROR_DETAIL
            ) from ex