retrieval_service/datastore/providers/postgres.py (381 lines of code) (raw):
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from datetime import datetime
from ipaddress import IPv4Address, IPv6Address
from typing import Any, Literal, Optional
import asyncpg
from pgvector.asyncpg import register_vector
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
import models
from .. import datastore
from ..helpers import format_sql
# Identifier of this datastore provider; returned by ``Client.kind`` below.
POSTGRES_IDENTIFIER = "postgres"
class Config(BaseModel, datastore.AbstractConfig):
    # Connection settings consumed by ``Client.create`` for the PostgreSQL
    # datastore provider.

    # Discriminator field: only the literal string "postgres" is accepted.
    kind: Literal["postgres"]

    # Server address, typed as an IP address; defaults to IPv4 loopback.
    host: IPv4Address | IPv6Address = IPv4Address("127.0.0.1")

    # Server port.
    port: int = 5432

    # Credentials and target database name; none of them has a default,
    # so all three must be supplied.
    user: str
    password: str
    database: str
class Client(datastore.Client[Config]):
    """Datastore client that runs its queries against PostgreSQL.

    Statements are executed through a SQLAlchemy ``AsyncEngine`` whose
    connections are opened with asyncpg and have the pgvector codec
    registered (see ``create``).

    Most query methods return a ``(result, sql)`` pair where ``sql`` is
    whatever ``format_sql(sql, params)`` produces (helper defined elsewhere;
    presumably a printable rendering of the statement -- confirm there).

    NOTE: the multi-line SQL literals are intentionally kept flush-left so
    that the statement text handed to ``format_sql`` stays exactly as written.
    """

    __async_engine: AsyncEngine

    @datastore.classproperty
    def kind(cls):
        """Return the identifier of this datastore provider."""
        return POSTGRES_IDENTIFIER

    def __init__(self, async_engine: AsyncEngine):
        """Wrap an existing engine; use ``create`` to build one from a Config."""
        self.__async_engine = async_engine

    @classmethod
    async def create(cls, config: Config) -> "Client":
        """Build a client whose engine connects with the settings in ``config``.

        Args:
            config: Host, port, credentials and database to connect to.

        Returns:
            A new ``Client`` wrapping the created engine.

        Raises:
            TypeError: If no engine object was produced.
        """

        async def getconn() -> asyncpg.Connection:
            # Called by the engine whenever it needs a new raw connection.
            conn: asyncpg.Connection = await asyncpg.connection.connect(
                host=str(config.host),
                user=config.user,
                password=config.password,
                database=config.database,
                port=config.port,
            )
            # Register the pgvector type on this connection so `vector`
            # columns can be sent and received.
            await register_vector(conn)
            return conn

        async_engine = create_async_engine(
            "postgresql+asyncpg://",
            async_creator=getconn,
        )
        # Defensive guard kept from the original implementation.
        if async_engine is None:
            raise TypeError("async_engine not instantiated")
        return cls(async_engine)

    @staticmethod
    async def _insert_rows(conn: Any, statement: str, rows: list[dict[str, Any]]) -> None:
        """Bulk-insert ``rows`` with ``statement``; do nothing when there are none.

        Passing an empty parameter list makes SQLAlchemy execute the statement
        without parameters, which fails for the required bind parameters, so
        the empty case is skipped explicitly.
        """
        if rows:
            await conn.execute(text(statement), rows)

    async def initialize_data(
        self,
        airports: list[models.Airport],
        amenities: list[models.Amenity],
        flights: list[models.Flight],
        policies: list[models.Policy],
    ) -> None:
        """Drop and recreate every table, then load the given rows.

        The ``airports``, ``amenities``, ``flights`` and ``policies`` tables
        are recreated and filled from the arguments; ``tickets`` is recreated
        empty. Everything runs on one connection and is committed once at the
        end. An empty list simply leaves the corresponding table empty.

        Args:
            airports: Rows for the ``airports`` table.
            amenities: Rows for the ``amenities`` table (with embeddings).
            flights: Rows for the ``flights`` table.
            policies: Rows for the ``policies`` table (with embeddings).
        """
        async with self.__async_engine.connect() as conn:
            # --- airports -------------------------------------------------
            # If the table already exists, drop it to avoid conflicts
            await conn.execute(text("DROP TABLE IF EXISTS airports CASCADE"))
            # Create a new table
            await conn.execute(
                text(
                    """
CREATE TABLE airports(
id INT PRIMARY KEY,
iata TEXT,
name TEXT,
city TEXT,
country TEXT
)
"""
                )
            )
            # Insert all the data
            await self._insert_rows(
                conn,
                """INSERT INTO airports VALUES (:id, :iata, :name, :city, :country)""",
                [
                    {
                        "id": a.id,
                        "iata": a.iata,
                        "name": a.name,
                        "city": a.city,
                        "country": a.country,
                    }
                    for a in airports
                ],
            )

            # The `vector` column type used below comes from this extension.
            await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))

            # --- amenities ------------------------------------------------
            # If the table already exists, drop it to avoid conflicts
            await conn.execute(text("DROP TABLE IF EXISTS amenities CASCADE"))
            # Create a new table
            await conn.execute(
                text(
                    """
CREATE TABLE amenities(
id INT PRIMARY KEY,
name TEXT,
description TEXT,
location TEXT,
terminal TEXT,
category TEXT,
hour TEXT,
sunday_start_hour TIME,
sunday_end_hour TIME,
monday_start_hour TIME,
monday_end_hour TIME,
tuesday_start_hour TIME,
tuesday_end_hour TIME,
wednesday_start_hour TIME,
wednesday_end_hour TIME,
thursday_start_hour TIME,
thursday_end_hour TIME,
friday_start_hour TIME,
friday_end_hour TIME,
saturday_start_hour TIME,
saturday_end_hour TIME,
content TEXT NOT NULL,
embedding vector(768) NOT NULL
)
"""
                )
            )
            # Insert all the data
            await self._insert_rows(
                conn,
                """
INSERT INTO amenities VALUES (:id, :name, :description, :location,
:terminal, :category, :hour, :sunday_start_hour, :sunday_end_hour,
:monday_start_hour, :monday_end_hour, :tuesday_start_hour,
:tuesday_end_hour, :wednesday_start_hour, :wednesday_end_hour,
:thursday_start_hour, :thursday_end_hour, :friday_start_hour,
:friday_end_hour, :saturday_start_hour, :saturday_end_hour, :content, :embedding)
""",
                [
                    {
                        "id": a.id,
                        "name": a.name,
                        "description": a.description,
                        "location": a.location,
                        "terminal": a.terminal,
                        "category": a.category,
                        "hour": a.hour,
                        "sunday_start_hour": a.sunday_start_hour,
                        "sunday_end_hour": a.sunday_end_hour,
                        "monday_start_hour": a.monday_start_hour,
                        "monday_end_hour": a.monday_end_hour,
                        "tuesday_start_hour": a.tuesday_start_hour,
                        "tuesday_end_hour": a.tuesday_end_hour,
                        "wednesday_start_hour": a.wednesday_start_hour,
                        "wednesday_end_hour": a.wednesday_end_hour,
                        "thursday_start_hour": a.thursday_start_hour,
                        "thursday_end_hour": a.thursday_end_hour,
                        "friday_start_hour": a.friday_start_hour,
                        "friday_end_hour": a.friday_end_hour,
                        "saturday_start_hour": a.saturday_start_hour,
                        "saturday_end_hour": a.saturday_end_hour,
                        "content": a.content,
                        "embedding": a.embedding,
                    }
                    for a in amenities
                ],
            )

            # --- flights --------------------------------------------------
            # If the table already exists, drop it to avoid conflicts
            await conn.execute(text("DROP TABLE IF EXISTS flights CASCADE"))
            # Create a new table
            await conn.execute(
                text(
                    """
CREATE TABLE flights(
id INTEGER PRIMARY KEY,
airline TEXT,
flight_number TEXT,
departure_airport TEXT,
arrival_airport TEXT,
departure_time TIMESTAMP,
arrival_time TIMESTAMP,
departure_gate TEXT,
arrival_gate TEXT
)
"""
                )
            )
            # Insert all the data
            await self._insert_rows(
                conn,
                """
INSERT INTO flights VALUES (:id, :airline, :flight_number,
:departure_airport, :arrival_airport, :departure_time,
:arrival_time, :departure_gate, :arrival_gate)
""",
                [
                    {
                        "id": f.id,
                        "airline": f.airline,
                        "flight_number": f.flight_number,
                        "departure_airport": f.departure_airport,
                        "arrival_airport": f.arrival_airport,
                        "departure_time": f.departure_time,
                        "arrival_time": f.arrival_time,
                        "departure_gate": f.departure_gate,
                        "arrival_gate": f.arrival_gate,
                    }
                    for f in flights
                ],
            )

            # --- tickets (created empty) ----------------------------------
            # If the table already exists, drop it to avoid conflicts
            await conn.execute(text("DROP TABLE IF EXISTS tickets CASCADE"))
            # Create a new table
            await conn.execute(
                text(
                    """
CREATE TABLE tickets(
user_id TEXT,
user_name TEXT,
user_email TEXT,
airline TEXT,
flight_number TEXT,
departure_airport TEXT,
arrival_airport TEXT,
departure_time TIMESTAMP,
arrival_time TIMESTAMP
)
"""
                )
            )

            # --- policies -------------------------------------------------
            # If the table already exists, drop it to avoid conflicts
            await conn.execute(text("DROP TABLE IF EXISTS policies CASCADE"))
            # Create a new table
            await conn.execute(
                text(
                    """
CREATE TABLE policies(
id INT PRIMARY KEY,
content TEXT NOT NULL,
embedding vector(768) NOT NULL
)
"""
                )
            )
            # Insert all the data
            await self._insert_rows(
                conn,
                """
INSERT INTO policies VALUES (:id, :content, :embedding)
""",
                [
                    {
                        "id": p.id,
                        "content": p.content,
                        "embedding": p.embedding,
                    }
                    for p in policies
                ],
            )
            await conn.commit()

    async def export_data(
        self,
    ) -> tuple[
        list[models.Airport],
        list[models.Amenity],
        list[models.Flight],
        list[models.Policy],
    ]:
        """Read back all airports, amenities, flights and policies, ordered by id.

        Returns:
            ``(airports, amenities, flights, policies)`` as validated models.
        """
        async with self.__async_engine.connect() as conn:
            # The queries run one after another: a single connection is not
            # meant to execute statements from several tasks at once, and
            # awaiting in sequence leaves no orphaned task behind if one of
            # the queries fails.
            airport_results = (
                (await conn.execute(text("""SELECT * FROM airports ORDER BY id ASC""")))
                .mappings()
                .fetchall()
            )
            amenity_results = (
                (await conn.execute(text("""SELECT * FROM amenities ORDER BY id ASC""")))
                .mappings()
                .fetchall()
            )
            flights_results = (
                (await conn.execute(text("""SELECT * FROM flights ORDER BY id ASC""")))
                .mappings()
                .fetchall()
            )
            policy_results = (
                (await conn.execute(text("""SELECT * FROM policies ORDER BY id ASC""")))
                .mappings()
                .fetchall()
            )

        airports = [models.Airport.model_validate(a) for a in airport_results]
        amenities = [models.Amenity.model_validate(a) for a in amenity_results]
        flights = [models.Flight.model_validate(f) for f in flights_results]
        policies = [models.Policy.model_validate(p) for p in policy_results]
        return airports, amenities, flights, policies

    async def get_airport_by_id(
        self, id: int
    ) -> tuple[Optional[models.Airport], Optional[str]]:
        """Look up one airport by primary key.

        Returns:
            ``(airport, sql)``, or ``(None, None)`` when no row matches.
        """
        async with self.__async_engine.connect() as conn:
            sql = """SELECT * FROM airports WHERE id=:id"""
            s = text(sql)
            params = {"id": id}
            result = (await conn.execute(s, params)).mappings().fetchone()

        if result is None:
            return None, None

        res = models.Airport.model_validate(result)
        return res, format_sql(sql, params)

    async def get_airport_by_iata(
        self, iata: str
    ) -> tuple[Optional[models.Airport], Optional[str]]:
        """Look up one airport by IATA code (matched with ``ILIKE``).

        Returns:
            ``(airport, sql)``, or ``(None, None)`` when no row matches.
        """
        async with self.__async_engine.connect() as conn:
            sql = """SELECT * FROM airports WHERE iata ILIKE :iata"""
            s = text(sql)
            params = {"iata": iata}
            result = (await conn.execute(s, params)).mappings().fetchone()

        if result is None:
            return None, None

        res = models.Airport.model_validate(result)
        return res, format_sql(sql, params)

    async def search_airports(
        self,
        country: Optional[str] = None,
        city: Optional[str] = None,
        name: Optional[str] = None,
    ) -> tuple[list[models.Airport], Optional[str]]:
        """Search airports by any combination of country, city and name.

        A filter passed as ``None`` is ignored. ``country`` and ``city`` are
        matched with ``ILIKE`` against the whole value, ``name`` as a
        substring. At most 10 rows are returned.

        Returns:
            ``(airports, sql)``; the list may be empty.
        """
        async with self.__async_engine.connect() as conn:
            sql = """
SELECT * FROM airports
WHERE (CAST(:country AS TEXT) IS NULL OR country ILIKE :country)
AND (CAST(:city AS TEXT) IS NULL OR city ILIKE :city)
AND (CAST(:name AS TEXT) IS NULL OR name ILIKE '%' || :name || '%')
LIMIT 10
"""
            s = text(sql)
            params = {
                "country": country,
                "city": city,
                "name": name,
            }
            results = (await conn.execute(s, params)).mappings().fetchall()

        res = [models.Airport.model_validate(r) for r in results]
        return res, format_sql(sql, params)

    async def get_amenity(
        self, id: int
    ) -> tuple[Optional[models.Amenity], Optional[str]]:
        """Look up one amenity by primary key.

        Only the descriptive columns are selected (no per-day hours, content
        or embedding).

        Returns:
            ``(amenity, sql)``, or ``(None, None)`` when no row matches.
        """
        async with self.__async_engine.connect() as conn:
            sql = """
SELECT id, name, description, location, terminal, category, hour
FROM amenities WHERE id=:id
"""
            s = text(sql)
            params = {"id": id}
            result = (await conn.execute(s, params)).mappings().fetchone()

        if result is None:
            return None, None

        res = models.Amenity.model_validate(result)
        return res, format_sql(sql, params)

    async def amenities_search(
        self, query_embedding: list[float], similarity_threshold: float, top_k: int
    ) -> tuple[list[Any], Optional[str]]:
        """Find the amenities whose embedding is closest to ``query_embedding``.

        Rows are kept when ``embedding <=> query_embedding`` is below
        ``similarity_threshold``, ordered by that value ascending and limited
        to ``top_k``.

        Returns:
            ``(rows, sql)`` where each row is a mapping of the selected columns.
        """
        async with self.__async_engine.connect() as conn:
            sql = """
SELECT name, description, location, terminal, category, hour
FROM amenities
WHERE (embedding <=> :query_embedding) < :similarity_threshold
ORDER BY (embedding <=> :query_embedding)
LIMIT :top_k
"""
            s = text(sql)
            params = {
                "query_embedding": query_embedding,
                "similarity_threshold": similarity_threshold,
                "top_k": top_k,
            }
            results = (await conn.execute(s, params)).mappings().fetchall()

        res = list(results)
        return res, format_sql(sql, params)

    async def get_flight(
        self, flight_id: int
    ) -> tuple[Optional[models.Flight], Optional[str]]:
        """Look up one flight by primary key.

        Returns:
            ``(flight, sql)``, or ``(None, None)`` when no row matches.
        """
        async with self.__async_engine.connect() as conn:
            sql = """
SELECT * FROM flights
WHERE id = :flight_id
"""
            s = text(sql)
            params = {"flight_id": flight_id}
            result = (await conn.execute(s, params)).mappings().fetchone()

        if result is None:
            return None, None

        res = models.Flight.model_validate(result)
        return res, format_sql(sql, params)

    async def search_flights_by_number(
        self,
        airline: str,
        number: str,
    ) -> tuple[list[models.Flight], Optional[str]]:
        """Find flights by exact airline and flight number (at most 10 rows).

        Returns:
            ``(flights, sql)``; the list may be empty.
        """
        async with self.__async_engine.connect() as conn:
            sql = """
SELECT * FROM flights
WHERE airline = :airline
AND flight_number = :number
LIMIT 10
"""
            s = text(sql)
            params = {
                "airline": airline,
                "number": number,
            }
            results = (await conn.execute(s, params)).mappings().fetchall()

        res = [models.Flight.model_validate(r) for r in results]
        return res, format_sql(sql, params)

    async def search_flights_by_airports(
        self,
        date: str,
        departure_airport: Optional[str] = None,
        arrival_airport: Optional[str] = None,
    ) -> tuple[list[models.Flight], Optional[str]]:
        """Find flights departing on ``date``, optionally filtered by airports.

        Args:
            date: Departure day formatted as ``YYYY-MM-DD``.
            departure_airport: ``ILIKE`` filter on the departure airport, or
                ``None`` for no filter.
            arrival_airport: ``ILIKE`` filter on the arrival airport, or
                ``None`` for no filter.

        Returns:
            ``(flights, sql)`` with at most 10 flights.

        Raises:
            ValueError: If ``date`` is not in ``YYYY-MM-DD`` format.
        """
        # Parse before touching the pool so bad input fails without checking
        # out a connection (same order as validate_ticket / insert_ticket).
        day_start = datetime.strptime(date, "%Y-%m-%d")
        async with self.__async_engine.connect() as conn:
            sql = """
SELECT * FROM flights
WHERE (CAST(:departure_airport AS TEXT) IS NULL OR departure_airport ILIKE :departure_airport)
AND (CAST(:arrival_airport AS TEXT) IS NULL OR arrival_airport ILIKE :arrival_airport)
AND departure_time >= CAST(:datetime AS timestamp)
AND departure_time < CAST(:datetime AS timestamp) + interval '1 day'
LIMIT 10
"""
            s = text(sql)
            params = {
                "departure_airport": departure_airport,
                "arrival_airport": arrival_airport,
                "datetime": day_start,
            }
            results = (await conn.execute(s, params)).mappings().fetchall()

        res = [models.Flight.model_validate(r) for r in results]
        return res, format_sql(sql, params)

    async def validate_ticket(
        self,
        airline: str,
        flight_number: str,
        departure_airport: str,
        departure_time: str,
    ) -> tuple[Optional[models.Flight], Optional[str]]:
        """Find the flight a ticket refers to.

        Args:
            airline: Airline, matched with ``ILIKE``.
            flight_number: Flight number, matched with ``ILIKE``.
            departure_airport: Departure airport, matched with ``ILIKE``.
            departure_time: Exact departure time as ``YYYY-MM-DD HH:MM:SS``.

        Returns:
            ``(flight, sql)``, or ``(None, None)`` when no flight matches.

        Raises:
            ValueError: If ``departure_time`` is not in the expected format.
        """
        departure_time_datetime = datetime.strptime(departure_time, "%Y-%m-%d %H:%M:%S")
        async with self.__async_engine.connect() as conn:
            sql = """
SELECT * FROM flights
WHERE airline ILIKE :airline
AND flight_number ILIKE :flight_number
AND departure_airport ILIKE :departure_airport
AND departure_time = :departure_time
"""
            s = text(sql)
            params = {
                "airline": airline,
                "flight_number": flight_number,
                "departure_airport": departure_airport,
                "departure_time": departure_time_datetime,
            }
            result = (await conn.execute(s, params)).mappings().fetchone()

        if result is None:
            return None, None

        res = models.Flight.model_validate(result)
        return res, format_sql(sql, params)

    async def insert_ticket(
        self,
        user_id: str,
        user_name: str,
        user_email: str,
        airline: str,
        flight_number: str,
        departure_airport: str,
        arrival_airport: str,
        departure_time: str,
        arrival_time: str,
    ) -> None:
        """Insert one row into ``tickets`` and commit it.

        Args:
            departure_time: Departure time as ``YYYY-MM-DD HH:MM:SS``.
            arrival_time: Arrival time as ``YYYY-MM-DD HH:MM:SS``.
            (The remaining arguments are stored as given.)

        Raises:
            ValueError: If either time string is not in the expected format.
            Exception: If the statement reports that no row was inserted.
        """
        departure_time_datetime = datetime.strptime(departure_time, "%Y-%m-%d %H:%M:%S")
        arrival_time_datetime = datetime.strptime(arrival_time, "%Y-%m-%d %H:%M:%S")
        async with self.__async_engine.connect() as conn:
            s = text(
                """
INSERT INTO tickets (
user_id,
user_name,
user_email,
airline,
flight_number,
departure_airport,
arrival_airport,
departure_time,
arrival_time
) VALUES (
:user_id,
:user_name,
:user_email,
:airline,
:flight_number,
:departure_airport,
:arrival_airport,
:departure_time,
:arrival_time
);
"""
            )
            params = {
                "user_id": user_id,
                "user_name": user_name,
                "user_email": user_email,
                "airline": airline,
                "flight_number": flight_number,
                "departure_airport": departure_airport,
                "arrival_airport": arrival_airport,
                "departure_time": departure_time_datetime,
                "arrival_time": arrival_time_datetime,
            }
            result = await conn.execute(s, params)
            # A result object is always truthy, so testing the object itself
            # can never detect a failed insert; look at the affected-row count
            # instead, and do so before committing. Only an explicit 0 is
            # treated as failure so a driver reporting "unknown" (-1) still
            # succeeds.
            if result.rowcount == 0:
                raise Exception("Ticket Insertion failure")
            await conn.commit()

    async def list_tickets(
        self,
        user_id: str,
    ) -> tuple[list[Any], Optional[str]]:
        """List the tickets stored for ``user_id``.

        Returns:
            ``(rows, sql)`` where each row is a mapping of the selected columns.
        """
        async with self.__async_engine.connect() as conn:
            sql = """
SELECT user_name, airline, flight_number, departure_airport, arrival_airport, departure_time, arrival_time FROM tickets
WHERE user_id = :user_id
"""
            s = text(sql)
            params = {
                "user_id": user_id,
            }
            results = (await conn.execute(s, params)).mappings().fetchall()

        res = list(results)
        return res, format_sql(sql, params)

    async def policies_search(
        self, query_embedding: list[float], similarity_threshold: float, top_k: int
    ) -> tuple[list[str], Optional[str]]:
        """Find the policy texts whose embedding is closest to ``query_embedding``.

        Rows are kept when ``embedding <=> query_embedding`` is below
        ``similarity_threshold``, ordered by that value ascending and limited
        to ``top_k``.

        Returns:
            ``(contents, sql)`` where ``contents`` holds the ``content``
            column of each matching row.
        """
        async with self.__async_engine.connect() as conn:
            sql = """
SELECT content
FROM policies
WHERE (embedding <=> :query_embedding) < :similarity_threshold
ORDER BY (embedding <=> :query_embedding)
LIMIT :top_k
"""
            s = text(sql)
            params = {
                "query_embedding": query_embedding,
                "similarity_threshold": similarity_threshold,
                "top_k": top_k,
            }
            results = (await conn.execute(s, params)).mappings().fetchall()

        res = [r["content"] for r in results]
        return res, format_sql(sql, params)

    async def close(self):
        """Dispose of the engine and the connections it holds."""
        await self.__async_engine.dispose()