retrieval_service/datastore/providers/firestore.py (454 lines of code) (raw):

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from datetime import datetime, timedelta
from typing import Any, Literal, Optional

from google.cloud.firestore import AsyncClient  # type: ignore
from google.cloud.firestore_v1.async_collection import AsyncCollectionReference
from google.cloud.firestore_v1.async_query import AsyncQuery
from google.cloud.firestore_v1.base_query import FieldFilter
from google.cloud.firestore_v1.base_vector_query import DistanceMeasure
from google.cloud.firestore_v1.vector import Vector
from pydantic import BaseModel

import models

from .. import datastore

FIRESTORE_IDENTIFIER = "firestore"

# Flight timestamps are stored as strings in this format. Its fixed-width,
# most-significant-field-first layout makes lexicographic order equal
# chronological order, which the range query in search_flights_by_airports
# relies on.
_TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"

# Maximum number of write/delete RPCs awaited together, to avoid gRPC
# batch write timeout errors on large collections.
_BATCH_SIZE = 10000

# Day-name prefixes of the per-day opening-hour fields on models.Amenity.
_WEEKDAYS = (
    "sunday",
    "monday",
    "tuesday",
    "wednesday",
    "thursday",
    "friday",
    "saturday",
)


def _amenity_to_document(amenity: models.Amenity) -> dict[str, Any]:
    """Build the dict stored in the "amenities" collection for one amenity.

    Firestore has no time-of-day type, so each ``<day>_start_hour`` /
    ``<day>_end_hour`` value is stored as its ``str()`` form, or None when
    the value is falsy (e.g. unset).
    """
    document: dict[str, Any] = {
        "name": amenity.name,
        "description": amenity.description,
        "location": amenity.location,
        "terminal": amenity.terminal,
        "category": amenity.category,
        "hour": amenity.hour,
    }
    for day in _WEEKDAYS:
        for bound in ("start", "end"):
            field = f"{day}_{bound}_hour"
            value = getattr(amenity, field)
            document[field] = str(value) if value else None
    document["content"] = amenity.content
    # Vector does not accept None, so a missing embedding becomes empty.
    document["embedding"] = Vector(amenity.embedding or [])
    return document


class Config(BaseModel, datastore.AbstractConfig):
    """Configuration for the Firestore datastore provider."""

    kind: Literal["firestore"]
    # GCP project passed to AsyncClient; when None the client library
    # presumably infers the project from the environment — confirm.
    projectId: Optional[str] = None


class Client(datastore.Client[Config]):
    """Datastore client backed by Google Cloud Firestore.

    Airports, amenities, flights and policies live in collections of the
    same name, with each document keyed by ``str(model.id)``. Ticket
    operations are not implemented for this provider.
    """

    __client: AsyncClient

    @datastore.classproperty
    def kind(cls):
        return FIRESTORE_IDENTIFIER

    def __init__(self, client: AsyncClient):
        self.__client = client
        # Base queries for the vector (find_nearest) searches.
        self.__policies_collection = AsyncQuery(self.__client.collection("policies"))
        self.__amenities_collection = AsyncQuery(self.__client.collection("amenities"))

    @classmethod
    async def create(cls, config: Config) -> "Client":
        """Create a client connected to ``config.projectId``."""
        return cls(AsyncClient(project=config.projectId))

    async def __delete_collections(
        self, collection_list: list[AsyncCollectionReference]
    ):
        """Delete every document in each of the given collections.

        An empty or missing collection simply streams no documents. The
        references are collected first, then deleted in batches of
        _BATCH_SIZE, so a huge collection does not queue an unbounded
        number of concurrent RPCs.
        """
        for collection_ref in collection_list:
            refs = [doc.reference async for doc in collection_ref.stream()]
            for start in range(0, len(refs), _BATCH_SIZE):
                await asyncio.gather(
                    *(ref.delete() for ref in refs[start : start + _BATCH_SIZE])
                )

    async def parse_index_info(self, line: str) -> tuple[str, str]:
        """Split an index resource name into ``(collection_name, index_id)``.

        The collection is the third-from-last "/"-separated segment and the
        index id is the last one. The expected shape is presumably
        ``.../collectionGroups/<collection>/indexes/<index-id>``.
        """
        parts = line.split("/")
        collection_name = parts[-3]
        index_id = parts[-1]
        return collection_name, index_id

    async def __get_indices(self) -> dict[str, str]:
        """Return a mapping of collection name to composite index id.

        Runs ``gcloud alpha firestore indexes composite list``. If a
        collection has several composite indexes, only the last one listed
        is kept.
        """
        list_vector_index_process = await asyncio.create_subprocess_exec(
            "gcloud",
            "alpha",
            "firestore",
            "indexes",
            "composite",
            "list",
            "--database=(default)",
            "--format=value(name)",  # prints only the name field
            stdout=asyncio.subprocess.PIPE,
        )
        # Only stdout is piped. stderr is inherited from this process, and
        # the exit code is not checked.
        stdout, __ = await list_vector_index_process.communicate()

        indices = {}
        for line in stdout.decode().strip().split("\n"):
            if line:
                collection, index_id = await self.parse_index_info(line)
                indices[collection] = index_id
        return indices

    async def __delete_vector_index(self, indices: list[str]):
        """Delete each non-empty composite index id via ``gcloud``, one by one."""
        for index in indices:
            if index:
                delete_vector_index = await asyncio.create_subprocess_exec(
                    "gcloud",
                    "alpha",
                    "firestore",
                    "indexes",
                    "composite",
                    "delete",
                    index,
                    "--database=(default)",
                    "--quiet",  # suppress the interactive delete confirmation
                )
                await delete_vector_index.wait()

    async def __create_vector_index(self, collection_name: str):
        """Create a 768-dimension flat vector index on the ``embedding`` field."""
        create_vector_index = await asyncio.create_subprocess_exec(
            "gcloud",
            "alpha",
            "firestore",
            "indexes",
            "composite",
            "create",
            f"--collection-group={collection_name}",
            "--query-scope=COLLECTION",
            '--field-config=field-path=embedding,vector-config={"dimension":768,"flat":"{}"}',
            "--database=(default)",
        )
        await create_vector_index.wait()

    async def initialize_data(
        self,
        airports: list[models.Airport],
        amenities: list[models.Amenity],
        flights: list[models.Flight],
        policies: list[models.Policy],
    ) -> None:
        """Replace all stored data with the given models.

        Steps, in order:
        1. Delete existing documents in the four collections.
        2. Delete the composite index found for "amenities" and "policies".
        3. Write the new documents, one collection after another.
        4. Recreate the vector indexes.
        """
        airports_ref = self.__client.collection("airports")
        amenities_ref = self.__client.collection("amenities")
        flights_ref = self.__client.collection("flights")
        policies_ref = self.__client.collection("policies")
        await self.__delete_collections(
            [airports_ref, amenities_ref, flights_ref, policies_ref]
        )

        indices = await self.__get_indices()
        await self.__delete_vector_index(
            [indices.get("amenities", ""), indices.get("policies", "")]
        )

        await asyncio.gather(
            *(
                airports_ref.document(str(airport.id)).set(
                    {
                        "iata": airport.iata,
                        "name": airport.name,
                        "city": airport.city,
                        "country": airport.country,
                    }
                )
                for airport in airports
            )
        )

        await asyncio.gather(
            *(
                amenities_ref.document(str(amenity.id)).set(
                    _amenity_to_document(amenity)
                )
                for amenity in amenities
            )
        )

        pending_flights = []
        for flight in flights:
            pending_flights.append(
                flights_ref.document(str(flight.id)).set(
                    {
                        "airline": flight.airline,
                        "flight_number": flight.flight_number,
                        "departure_airport": flight.departure_airport,
                        "arrival_airport": flight.arrival_airport,
                        "departure_time": flight.departure_time.strftime(
                            _TIMESTAMP_FORMAT
                        ),
                        "arrival_time": flight.arrival_time.strftime(
                            _TIMESTAMP_FORMAT
                        ),
                        "departure_gate": flight.departure_gate,
                        "arrival_gate": flight.arrival_gate,
                    }
                )
            )
            if len(pending_flights) >= _BATCH_SIZE:
                # Avoid gRPC batch write timeout errors.
                await asyncio.gather(*pending_flights)
                pending_flights.clear()
        await asyncio.gather(*pending_flights)

        await asyncio.gather(
            *(
                policies_ref.document(str(policy.id)).set(
                    {
                        "content": policy.content,
                        # Vector does not accept None.
                        "embedding": Vector(policy.embedding or []),
                    }
                )
                for policy in policies
            )
        )

        await self.__create_vector_index("amenities")
        await self.__create_vector_index("policies")

    async def export_data(
        self,
    ) -> tuple[
        list[models.Airport],
        list[models.Amenity],
        list[models.Flight],
        list[models.Policy],
    ]:
        """Read every document back into models, using the document id as ``id``."""
        airport_docs = self.__client.collection("airports").stream()
        amenities_docs = self.__client.collection("amenities").stream()
        flights_docs = self.__client.collection("flights").stream()
        policies_docs = self.__client.collection("policies").stream()

        airports = []
        async for doc in airport_docs:
            airport_dict = doc.to_dict()
            airport_dict["id"] = doc.id
            airports.append(models.Airport.model_validate(airport_dict))

        amenities = []
        async for doc in amenities_docs:
            amenity_dict = doc.to_dict()
            amenity_dict["id"] = doc.id
            # Vector -> plain list for the pydantic model.
            amenity_dict["embedding"] = list(amenity_dict["embedding"])
            amenities.append(models.Amenity.model_validate(amenity_dict))

        flights = []
        async for doc in flights_docs:
            flight_dict = doc.to_dict()
            flight_dict["id"] = doc.id
            flights.append(models.Flight.model_validate(flight_dict))

        policies = []
        async for doc in policies_docs:
            policy_dict = doc.to_dict()
            policy_dict["id"] = doc.id
            policy_dict["embedding"] = list(policy_dict["embedding"])
            policies.append(models.Policy.model_validate(policy_dict))

        return airports, amenities, flights, policies

    async def get_airport_by_id(
        self, id: int
    ) -> tuple[Optional[models.Airport], Optional[str]]:
        """Fetch the airport whose document id is ``str(id)``; None if absent."""
        # Documents are keyed by id and carry no "id" field, so read directly.
        airport_doc = await self.__client.collection("airports").document(str(id)).get()
        if not airport_doc.exists:
            return None, None
        airport_dict = airport_doc.to_dict() | {"id": airport_doc.id}
        return models.Airport.model_validate(airport_dict), None

    async def get_airport_by_iata(
        self, iata: str
    ) -> tuple[Optional[models.Airport], Optional[str]]:
        """Fetch the first airport with the given IATA code; None if none match."""
        query = (
            self.__client.collection("airports")
            .where(filter=FieldFilter("iata", "==", iata))
            .limit(1)
        )
        # AsyncQuery.get() returns a list of snapshots, not a single one.
        airport_docs = await query.get()
        if not airport_docs:
            return None, None
        airport_doc = airport_docs[0]
        airport_dict = airport_doc.to_dict() | {"id": airport_doc.id}
        return models.Airport.model_validate(airport_dict), None

    async def search_airports(
        self,
        country: Optional[str] = None,
        city: Optional[str] = None,
        name: Optional[str] = None,
    ) -> tuple[list[models.Airport], Optional[str]]:
        """Return up to 10 airports matching all given filters.

        ``name`` is a prefix match, implemented as a range up to
        ``name + "\\uf8ff"``.
        """
        query = self.__client.collection("airports")
        if country is not None:
            query = query.where(filter=FieldFilter("country", "==", country))
        if city is not None:
            query = query.where(filter=FieldFilter("city", "==", city))
        if name is not None:
            query = query.where(filter=FieldFilter("name", ">=", name)).where(
                filter=FieldFilter("name", "<=", name + "\uf8ff")
            )
        query = query.limit(10)

        airports = []
        async for doc in query.stream():
            airport_dict = doc.to_dict() | {"id": doc.id}
            airports.append(models.Airport.model_validate(airport_dict))
        return airports, None

    async def get_amenity(
        self, id: int
    ) -> tuple[Optional[models.Amenity], Optional[str]]:
        """Fetch the amenity whose document id is ``str(id)``; None if absent."""
        amenity_doc = (
            await self.__client.collection("amenities").document(str(id)).get()
        )
        if not amenity_doc.exists:
            return None, None
        amenity_dict = amenity_doc.to_dict() | {"id": amenity_doc.id}
        amenity_dict["embedding"] = list(amenity_dict["embedding"])
        return models.Amenity.model_validate(amenity_dict), None

    async def amenities_search(
        self, query_embedding: list[float], similarity_threshold: float, top_k: int
    ) -> tuple[list[Any], Optional[str]]:
        """Return the ``top_k`` nearest amenities by dot-product distance.

        NOTE: ``similarity_threshold`` is currently not applied.
        """
        # Using the same similarity metric as the embedding model's training
        # method produces the most accurate results.
        query = self.__amenities_collection.find_nearest(
            vector_field="embedding",
            query_vector=Vector(query_embedding),
            distance_measure=DistanceMeasure.DOT_PRODUCT,
            limit=top_k,
        )

        amenities = []
        async for doc in query.stream():
            amenities.append(
                {
                    "id": doc.id,
                    "category": doc.get("category"),
                    "description": doc.get("description"),
                    "hour": doc.get("hour"),
                    "location": doc.get("location"),
                    "name": doc.get("name"),
                    "terminal": doc.get("terminal"),
                }
            )
        return amenities, None

    async def get_flight(
        self, flight_id: int
    ) -> tuple[Optional[models.Flight], Optional[str]]:
        """Fetch the flight whose document id is ``str(flight_id)``; None if absent."""
        flight_doc = (
            await self.__client.collection("flights").document(str(flight_id)).get()
        )
        if not flight_doc.exists:
            return None, None
        flight_dict = flight_doc.to_dict() | {"id": flight_doc.id}
        return models.Flight.model_validate(flight_dict), None

    async def search_flights_by_number(
        self,
        airline: str,
        number: str,
    ) -> tuple[list[models.Flight], Optional[str]]:
        """Return up to 10 flights with the given airline and flight number."""
        query = (
            self.__client.collection("flights")
            .where(filter=FieldFilter("airline", "==", airline))
            .where(filter=FieldFilter("flight_number", "==", number))
            .limit(10)
        )

        flights = []
        async for doc in query.stream():
            flight_dict = doc.to_dict() | {"id": doc.id}
            flights.append(models.Flight.model_validate(flight_dict))
        return flights, None

    async def search_flights_by_airports(
        self,
        date: str,
        departure_airport: Optional[str] = None,
        arrival_airport: Optional[str] = None,
    ) -> tuple[list[models.Flight], Optional[str]]:
        """Return up to 10 flights departing on ``date`` (``YYYY-MM-DD``).

        Results are optionally restricted to the given departure and/or
        arrival airport.

        NOTE(review): combining the equality filters with the departure_time
        range likely requires a Firestore composite index — confirm it exists.
        """
        day_start = datetime.strptime(date, "%Y-%m-%d")
        day_end = day_start + timedelta(days=1)
        # departure_time is stored as a _TIMESTAMP_FORMAT string, so compare
        # against strings. Firestore orders values by type first, so a
        # datetime bound would never match a string field.
        query = (
            self.__client.collection("flights")
            .where(
                filter=FieldFilter(
                    "departure_time", ">=", day_start.strftime(_TIMESTAMP_FORMAT)
                )
            )
            .where(
                filter=FieldFilter(
                    "departure_time", "<", day_end.strftime(_TIMESTAMP_FORMAT)
                )
            )
        )
        if departure_airport is not None:
            query = query.where(
                filter=FieldFilter("departure_airport", "==", departure_airport)
            )
        if arrival_airport is not None:
            query = query.where(
                filter=FieldFilter("arrival_airport", "==", arrival_airport)
            )
        query = query.limit(10)

        flights = []
        async for doc in query.stream():
            flight_dict = doc.to_dict() | {"id": doc.id}
            flights.append(models.Flight.model_validate(flight_dict))
        return flights, None

    async def validate_ticket(
        self,
        airline: str,
        flight_number: str,
        departure_airport: str,
        departure_time: str,
    ) -> tuple[Optional[models.Flight], Optional[str]]:
        raise NotImplementedError("Not Implemented")

    async def insert_ticket(
        self,
        user_id: str,
        user_name: str,
        user_email: str,
        airline: str,
        flight_number: str,
        departure_airport: str,
        arrival_airport: str,
        departure_time: str,
        arrival_time: str,
    ):
        raise NotImplementedError("Not Implemented")

    async def list_tickets(
        self,
        user_id: str,
    ) -> tuple[list[Any], Optional[str]]:
        raise NotImplementedError("Not Implemented")

    async def policies_search(
        self, query_embedding: list[float], similarity_threshold: float, top_k: int
    ) -> tuple[list[str], Optional[str]]:
        """Return the content of the ``top_k`` nearest policies.

        NOTE: ``similarity_threshold`` is currently not applied.
        """
        query = self.__policies_collection.find_nearest(
            vector_field="embedding",
            query_vector=Vector(query_embedding),
            distance_measure=DistanceMeasure.DOT_PRODUCT,
            limit=top_k,
        )

        policies = []
        async for doc in query.stream():
            policies.append(doc.get("content"))
        return policies, None

    async def close(self):
        # NOTE(review): called without await; confirm AsyncClient.close() is
        # synchronous in the installed google-cloud-firestore version.
        self.__client.close()