retrieval_service/datastore/datastore.py (228 lines of code) (raw):
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
from abc import ABC, abstractmethod
from typing import Any, Generic, List, Optional, TypeVar
import models
class AbstractConfig(ABC):
kind: str
C = TypeVar("C", bound=AbstractConfig)
class classproperty:
def __init__(self, func):
self.fget = func
def __get__(self, _, owner):
return self.fget(owner)
class Client(ABC, Generic[C]):
@classproperty
@abstractmethod
def kind(cls):
pass
@classmethod
@abstractmethod
async def create(cls, config: C) -> "Client":
pass
async def load_dataset(
self, airports_ds_path, amenities_ds_path, flights_ds_path, policies_ds_path
) -> tuple[
List[models.Airport],
List[models.Amenity],
List[models.Flight],
List[models.Policy],
]:
airports: List[models.Airport] = []
with open(airports_ds_path, "r") as f:
reader = csv.DictReader(f, delimiter=",")
airports = [models.Airport.model_validate(line) for line in reader]
amenities: list[models.Amenity] = []
with open(amenities_ds_path, "r") as f:
reader = csv.DictReader(f, delimiter=",")
amenities = [models.Amenity.model_validate(line) for line in reader]
flights: List[models.Flight] = []
with open(flights_ds_path, "r") as f:
reader = csv.DictReader(f, delimiter=",")
flights = [models.Flight.model_validate(line) for line in reader]
policies: List[models.Policy] = []
with open(policies_ds_path, "r") as f:
reader = csv.DictReader(f, delimiter=",")
policies = [models.Policy.model_validate(line) for line in reader]
return airports, amenities, flights, policies
async def export_dataset(
self,
airports,
amenities,
flights,
policies,
airports_new_path,
amenities_new_path,
flights_new_path,
policies_new_path,
) -> None:
with open(airports_new_path, "w") as f:
col_names = ["id", "iata", "name", "city", "country"]
writer = csv.DictWriter(f, col_names, delimiter=",")
writer.writeheader()
for a in airports:
writer.writerow(a.model_dump())
with open(amenities_new_path, "w") as f:
col_names = [
"id",
"name",
"description",
"location",
"terminal",
"category",
"hour",
"sunday_start_hour",
"sunday_end_hour",
"monday_start_hour",
"monday_end_hour",
"tuesday_start_hour",
"tuesday_end_hour",
"wednesday_start_hour",
"wednesday_end_hour",
"thursday_start_hour",
"thursday_end_hour",
"friday_start_hour",
"friday_end_hour",
"saturday_start_hour",
"saturday_end_hour",
"content",
"embedding",
]
writer = csv.DictWriter(f, col_names, delimiter=",")
writer.writeheader()
for a in amenities:
writer.writerow(a.model_dump())
with open(flights_new_path, "w") as f:
col_names = [
"id",
"airline",
"flight_number",
"departure_airport",
"arrival_airport",
"departure_time",
"arrival_time",
"departure_gate",
"arrival_gate",
]
writer = csv.DictWriter(f, col_names, delimiter=",")
writer.writeheader()
for fl in flights:
writer.writerow(fl.model_dump())
with open(policies_new_path, "w") as f:
col_names = [
"id",
"content",
"embedding",
]
writer = csv.DictWriter(f, col_names, delimiter=",")
writer.writeheader()
for p in policies:
writer.writerow(p.model_dump())
@abstractmethod
async def initialize_data(
self,
airports: list[models.Airport],
amenities: list[models.Amenity],
flights: list[models.Flight],
policies: list[models.Policy],
) -> None:
pass
@abstractmethod
async def export_data(
self,
) -> tuple[
list[models.Airport],
list[models.Amenity],
list[models.Flight],
list[models.Policy],
]:
pass
@abstractmethod
async def get_airport_by_id(
self, id: int
) -> tuple[Optional[models.Airport], Optional[str]]:
raise NotImplementedError("Subclass should implement this!")
@abstractmethod
async def get_airport_by_iata(
self, iata: str
) -> tuple[Optional[models.Airport], Optional[str]]:
raise NotImplementedError("Subclass should implement this!")
@abstractmethod
async def search_airports(
self,
country: Optional[str] = None,
city: Optional[str] = None,
name: Optional[str] = None,
) -> tuple[list[models.Airport], Optional[str]]:
raise NotImplementedError("Subclass should implement this!")
@abstractmethod
async def get_amenity(
self, id: int
) -> tuple[Optional[models.Amenity], Optional[str]]:
raise NotImplementedError("Subclass should implement this!")
@abstractmethod
async def amenities_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> tuple[list[Any], Optional[str]]:
raise NotImplementedError("Subclass should implement this!")
@abstractmethod
async def get_flight(
self, flight_id: int
) -> tuple[Optional[models.Flight], Optional[str]]:
raise NotImplementedError("Subclass should implement this!")
@abstractmethod
async def search_flights_by_number(
self,
airline: str,
flight_number: str,
) -> tuple[list[models.Flight], Optional[str]]:
raise NotImplementedError("Subclass should implement this!")
@abstractmethod
async def search_flights_by_airports(
self,
date,
departure_airport: Optional[str] = None,
arrival_airport: Optional[str] = None,
) -> tuple[list[models.Flight], Optional[str]]:
raise NotImplementedError("Subclass should implement this!")
@abstractmethod
async def validate_ticket(
self,
airline: str,
flight_number: str,
departure_airport: str,
departure_time: str,
) -> tuple[Optional[models.Flight], Optional[str]]:
raise NotImplementedError("Subclass should implement this!")
@abstractmethod
async def insert_ticket(
self,
user_id: str,
user_name: str,
user_email: str,
airline: str,
flight_number: str,
departure_airport: str,
arrival_airport: str,
departure_time: str,
arrival_time: str,
):
raise NotImplementedError("Subclass should implement this!")
@abstractmethod
async def list_tickets(
self,
user_id: str,
) -> tuple[list[Any], Optional[str]]:
raise NotImplementedError("Subclass should implement this!")
@abstractmethod
async def policies_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> tuple[list[str], Optional[str]]:
raise NotImplementedError("Subclass should implement this!")
@abstractmethod
async def close(self):
pass
async def create(config: AbstractConfig) -> Client:
for cls in Client.__subclasses__():
if config.kind == cls.kind:
return await cls.create(config) # type: ignore
raise TypeError(f"No clients of kind '{config.kind}'")