retrieval_service/datastore/datastore.py (228 lines of code) (raw):

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
from abc import ABC, abstractmethod
from typing import Any, Generic, List, Optional, TypeVar

import models


class AbstractConfig(ABC):
    """Base class for datastore configurations."""

    # Identifier compared against ``Client.kind`` by ``create`` to pick the
    # client implementation.
    kind: str


C = TypeVar("C", bound=AbstractConfig)


class classproperty:
    """Descriptor that calls the wrapped function with the owning class.

    Reading the attribute (on the class or on an instance) returns
    ``func(owner)``.
    """

    def __init__(self, func):
        self.fget = func

    def __get__(self, _, owner):
        # The instance argument is ignored on purpose: the value depends only
        # on the class the attribute is looked up on.
        return self.fget(owner)


def _read_csv(path, model) -> list:
    """Parse every row of the CSV file at ``path`` with ``model.model_validate``.

    The file is read with ``csv.DictReader`` so each row is handed to the
    model as a dict keyed by the header line.
    """
    # newline="" lets the csv module do its own newline handling, as the csv
    # documentation requires; otherwise newlines embedded in quoted fields are
    # not interpreted correctly.
    with open(path, "r", newline="") as f:
        reader = csv.DictReader(f, delimiter=",")
        return [model.model_validate(row) for row in reader]


def _write_csv(path, col_names: list[str], records) -> None:
    """Write a header plus one row per record to the CSV file at ``path``.

    Each record is serialised via ``record.model_dump()``.
    """
    # newline="" prevents an extra "\r" being emitted on platforms that use
    # "\r\n" line endings (see the csv module documentation).
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, col_names, delimiter=",")
        writer.writeheader()
        for record in records:
            writer.writerow(record.model_dump())


class Client(ABC, Generic[C]):
    """Abstract interface every datastore client implements.

    ``load_dataset`` / ``export_dataset`` are concrete CSV helpers; every
    other method must be supplied by the subclass.

    Most query methods return a 2-tuple whose second element is an
    ``Optional[str]``; its meaning is defined by the implementations and is
    not visible in this module (TODO confirm against a concrete client).
    """

    # NOTE(review): ``classproperty`` does not forward
    # ``__isabstractmethod__``, so the ABC machinery most likely does not
    # enforce that subclasses override ``kind`` -- confirm before relying on it.
    @classproperty
    @abstractmethod
    def kind(cls):
        """Identifier matched against ``AbstractConfig.kind`` by ``create``."""
        pass

    @classmethod
    @abstractmethod
    async def create(cls, config: C) -> "Client":
        """Build a client instance from ``config``."""
        pass

    async def load_dataset(
        self, airports_ds_path, amenities_ds_path, flights_ds_path, policies_ds_path
    ) -> tuple[
        List[models.Airport],
        List[models.Amenity],
        List[models.Flight],
        List[models.Policy],
    ]:
        """Load the four CSV datasets and validate each row into its model.

        Args:
            airports_ds_path: Path of the airports CSV file.
            amenities_ds_path: Path of the amenities CSV file.
            flights_ds_path: Path of the flights CSV file.
            policies_ds_path: Path of the policies CSV file.

        Returns:
            ``(airports, amenities, flights, policies)`` as lists of validated
            model instances, in file order.

        Raises:
            OSError: If a file cannot be opened.
            Any exception raised by ``model_validate`` for a malformed row.
        """
        # Files are read in the same order as before: airports, amenities,
        # flights, policies.
        airports: List[models.Airport] = _read_csv(airports_ds_path, models.Airport)
        amenities: List[models.Amenity] = _read_csv(amenities_ds_path, models.Amenity)
        flights: List[models.Flight] = _read_csv(flights_ds_path, models.Flight)
        policies: List[models.Policy] = _read_csv(policies_ds_path, models.Policy)
        return airports, amenities, flights, policies

    async def export_dataset(
        self,
        airports,
        amenities,
        flights,
        policies,
        airports_new_path,
        amenities_new_path,
        flights_new_path,
        policies_new_path,
    ) -> None:
        """Write the four collections out as CSV files (header + rows).

        Each item is serialised with ``model_dump()``; ``csv.DictWriter``
        raises ``ValueError`` if a dumped dict contains a key that is not in
        the column list for that file.

        Args:
            airports: Iterable of airport models.
            amenities: Iterable of amenity models.
            flights: Iterable of flight models.
            policies: Iterable of policy models.
            airports_new_path: Destination path for the airports CSV.
            amenities_new_path: Destination path for the amenities CSV.
            flights_new_path: Destination path for the flights CSV.
            policies_new_path: Destination path for the policies CSV.
        """
        _write_csv(
            airports_new_path,
            ["id", "iata", "name", "city", "country"],
            airports,
        )
        _write_csv(
            amenities_new_path,
            [
                "id",
                "name",
                "description",
                "location",
                "terminal",
                "category",
                "hour",
                "sunday_start_hour",
                "sunday_end_hour",
                "monday_start_hour",
                "monday_end_hour",
                "tuesday_start_hour",
                "tuesday_end_hour",
                "wednesday_start_hour",
                "wednesday_end_hour",
                "thursday_start_hour",
                "thursday_end_hour",
                "friday_start_hour",
                "friday_end_hour",
                "saturday_start_hour",
                "saturday_end_hour",
                "content",
                "embedding",
            ],
            amenities,
        )
        _write_csv(
            flights_new_path,
            [
                "id",
                "airline",
                "flight_number",
                "departure_airport",
                "arrival_airport",
                "departure_time",
                "arrival_time",
                "departure_gate",
                "arrival_gate",
            ],
            flights,
        )
        _write_csv(
            policies_new_path,
            [
                "id",
                "content",
                "embedding",
            ],
            policies,
        )

    @abstractmethod
    async def initialize_data(
        self,
        airports: list[models.Airport],
        amenities: list[models.Amenity],
        flights: list[models.Flight],
        policies: list[models.Policy],
    ) -> None:
        """Populate the datastore with the given records."""
        pass

    @abstractmethod
    async def export_data(
        self,
    ) -> tuple[
        list[models.Airport],
        list[models.Amenity],
        list[models.Flight],
        list[models.Policy],
    ]:
        """Return the datastore contents as ``(airports, amenities, flights, policies)``."""
        pass

    @abstractmethod
    async def get_airport_by_id(
        self, id: int
    ) -> tuple[Optional[models.Airport], Optional[str]]:
        """Look up a single airport by its ``id``."""
        raise NotImplementedError("Subclass should implement this!")

    @abstractmethod
    async def get_airport_by_iata(
        self, iata: str
    ) -> tuple[Optional[models.Airport], Optional[str]]:
        """Look up a single airport by its ``iata`` value."""
        raise NotImplementedError("Subclass should implement this!")

    @abstractmethod
    async def search_airports(
        self,
        country: Optional[str] = None,
        city: Optional[str] = None,
        name: Optional[str] = None,
    ) -> tuple[list[models.Airport], Optional[str]]:
        """Search airports using any combination of the optional filters."""
        raise NotImplementedError("Subclass should implement this!")

    @abstractmethod
    async def get_amenity(
        self, id: int
    ) -> tuple[Optional[models.Amenity], Optional[str]]:
        """Look up a single amenity by its ``id``."""
        raise NotImplementedError("Subclass should implement this!")

    @abstractmethod
    async def amenities_search(
        self, query_embedding: list[float], similarity_threshold: float, top_k: int
    ) -> tuple[list[Any], Optional[str]]:
        """Search amenities by ``query_embedding``.

        ``similarity_threshold`` and ``top_k`` are interpreted by the
        implementation.
        """
        raise NotImplementedError("Subclass should implement this!")

    @abstractmethod
    async def get_flight(
        self, flight_id: int
    ) -> tuple[Optional[models.Flight], Optional[str]]:
        """Look up a single flight by ``flight_id``."""
        raise NotImplementedError("Subclass should implement this!")

    @abstractmethod
    async def search_flights_by_number(
        self,
        airline: str,
        flight_number: str,
    ) -> tuple[list[models.Flight], Optional[str]]:
        """Search flights by ``airline`` and ``flight_number``."""
        raise NotImplementedError("Subclass should implement this!")

    @abstractmethod
    async def search_flights_by_airports(
        self,
        date,
        departure_airport: Optional[str] = None,
        arrival_airport: Optional[str] = None,
    ) -> tuple[list[models.Flight], Optional[str]]:
        """Search flights for ``date``, optionally filtered by airports."""
        raise NotImplementedError("Subclass should implement this!")

    @abstractmethod
    async def validate_ticket(
        self,
        airline: str,
        flight_number: str,
        departure_airport: str,
        departure_time: str,
    ) -> tuple[Optional[models.Flight], Optional[str]]:
        """Return the flight matching the given ticket details, if any."""
        raise NotImplementedError("Subclass should implement this!")

    @abstractmethod
    async def insert_ticket(
        self,
        user_id: str,
        user_name: str,
        user_email: str,
        airline: str,
        flight_number: str,
        departure_airport: str,
        arrival_airport: str,
        departure_time: str,
        arrival_time: str,
    ):
        """Record a ticket for the given user and flight details."""
        raise NotImplementedError("Subclass should implement this!")

    @abstractmethod
    async def list_tickets(
        self,
        user_id: str,
    ) -> tuple[list[Any], Optional[str]]:
        """List the tickets belonging to ``user_id``."""
        raise NotImplementedError("Subclass should implement this!")

    @abstractmethod
    async def policies_search(
        self, query_embedding: list[float], similarity_threshold: float, top_k: int
    ) -> tuple[list[str], Optional[str]]:
        """Search policies by ``query_embedding``.

        ``similarity_threshold`` and ``top_k`` are interpreted by the
        implementation.
        """
        raise NotImplementedError("Subclass should implement this!")

    @abstractmethod
    async def close(self):
        """Release any resources held by the client."""
        pass


async def create(config: AbstractConfig) -> Client:
    """Instantiate the ``Client`` subclass whose ``kind`` equals ``config.kind``.

    Only direct subclasses of ``Client`` are considered, because
    ``__subclasses__()`` does not return deeper descendants. A subclass is
    only visible once its module has been imported.

    Raises:
        TypeError: If no direct subclass has a matching ``kind``.
    """
    for cls in Client.__subclasses__():
        if config.kind == cls.kind:
            return await cls.create(config)  # type: ignore
    raise TypeError(f"No clients of kind '{config.kind}'")