src/retrieval_service/datastore/providers/cloudsql_postgres.py (282 lines of code) (raw):
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from datetime import datetime
from typing import Any, Dict, Literal, Optional
import asyncpg
import sqlalchemy
from google.cloud.sql.connector import Connector, IPTypes
from pgvector.asyncpg import register_vector
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
import models
from .. import datastore
# Identifier string for this datastore provider; the same value is the only
# literal accepted by Config.kind and is what Client.kind returns.
POSTGRES_IDENTIFIER = "cloudsql-postgres"
class Config(BaseModel, datastore.AbstractConfig):
    """Connection settings for the Cloud SQL for PostgreSQL datastore.

    ``project``, ``region`` and ``instance`` are joined as
    ``"project:region:instance"`` to form the Cloud SQL instance connection
    name; ``user``, ``password`` and ``database`` are passed to the connector
    when each connection is opened.
    """

    # Must be exactly "cloudsql-postgres"; presumably used to select this
    # provider when parsing datastore config — verify against the loader.
    kind: Literal["cloudsql-postgres"]
    # Google Cloud project that owns the Cloud SQL instance.
    project: str
    # Region of the Cloud SQL instance.
    region: str
    # Cloud SQL instance name (not the full connection name).
    instance: str
    # Database user to authenticate as.
    user: str
    # Password for ``user``. NOTE(review): stored as a plain str, so it will
    # appear in repr()/model_dump() output; avoid logging this model.
    password: str
    # Name of the database to connect to.
    database: str
class Client(datastore.Client[Config]):
    """Datastore client backed by Cloud SQL for PostgreSQL with pgvector.

    Connections are opened through the Cloud SQL Python Connector (PSC IP)
    and pooled by a SQLAlchemy ``AsyncEngine``.
    """

    __pool: AsyncEngine
    # Connector used to open every pooled connection; released in close().
    # None when the client was constructed directly around an existing engine.
    __connector: Optional[Connector]

    @datastore.classproperty
    def kind(cls):
        """Provider identifier for this client ("cloudsql-postgres")."""
        return POSTGRES_IDENTIFIER

    def __init__(self, pool: AsyncEngine, connector: Optional[Connector] = None):
        """Wrap an existing async engine.

        Args:
            pool: SQLAlchemy async engine used for every query.
            connector: Optional Cloud SQL connector that backs ``pool``. When
                given, it is closed by ``close()`` after the pool is disposed.
        """
        self.__pool = pool
        self.__connector = connector

    @classmethod
    async def create(cls, config: Config) -> "Client":
        """Create a client whose pool connects to the instance in ``config``.

        Each new pooled connection ensures the ``google_ml_integration`` and
        ``vector`` extensions exist and registers the pgvector codec on it.
        """
        loop = asyncio.get_running_loop()
        # One Connector for the whole pool (instead of a fresh one per
        # connection); it is closed in close().
        connector = Connector(loop=loop)

        async def getconn() -> asyncpg.Connection:
            conn: asyncpg.Connection = await connector.connect_async(
                # Cloud SQL instance connection name
                f"{config.project}:{config.region}:{config.instance}",
                "asyncpg",
                user=f"{config.user}",
                password=f"{config.password}",
                db=f"{config.database}",
                ip_type=IPTypes.PSC,
            )
            try:
                await conn.execute('CREATE EXTENSION IF NOT EXISTS google_ml_integration')
                await conn.execute('CREATE EXTENSION IF NOT EXISTS vector')
                # The vector type must exist before its codec can be registered.
                await register_vector(conn)
            except BaseException:
                # Don't leak the raw connection when per-connection setup fails.
                conn.terminate()
                raise
            return conn

        pool = create_async_engine(
            "postgresql+asyncpg://",
            async_creator=getconn,
        )
        return cls(pool, connector)

    async def initialize_data(
        self,
        airports: list[models.Airport],
        amenities: list[models.Amenity],
        flights: list[models.Flight],
    ) -> None:
        """Drop, recreate and populate the airports, amenities and flights tables.

        All statements run on a single connection and are committed together
        at the end. An empty input list leaves its (freshly created) table empty.
        """
        async with self.__pool.connect() as conn:
            # If the table already exists, drop it to avoid conflicts
            await conn.execute(text("DROP TABLE IF EXISTS airports CASCADE"))
            # Create a new table
            await conn.execute(
                text(
                    """
                    CREATE TABLE airports(
                      id INT PRIMARY KEY,
                      iata TEXT,
                      name TEXT,
                      city TEXT,
                      country TEXT
                    )
                    """
                )
            )
            # Insert all the data. Executing with an empty parameter list
            # would fail on the unbound placeholders, so skip it.
            if airports:
                await conn.execute(
                    text(
                        """INSERT INTO airports VALUES (:id, :iata, :name, :city, :country)"""
                    ),
                    [
                        {
                            "id": a.id,
                            "iata": a.iata,
                            "name": a.name,
                            "city": a.city,
                            "country": a.country,
                        }
                        for a in airports
                    ],
                )

            await conn.execute(text("CREATE EXTENSION IF NOT EXISTS google_ml_integration"))
            await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
            # If the table already exists, drop it to avoid conflicts
            await conn.execute(text("DROP TABLE IF EXISTS amenities CASCADE"))
            # Create a new table
            await conn.execute(
                text(
                    """
                    CREATE TABLE amenities(
                      id INT PRIMARY KEY,
                      name TEXT,
                      description TEXT,
                      location TEXT,
                      terminal TEXT,
                      category TEXT,
                      hour TEXT,
                      sunday_start_hour TIME,
                      sunday_end_hour TIME,
                      monday_start_hour TIME,
                      monday_end_hour TIME,
                      tuesday_start_hour TIME,
                      tuesday_end_hour TIME,
                      wednesday_start_hour TIME,
                      wednesday_end_hour TIME,
                      thursday_start_hour TIME,
                      thursday_end_hour TIME,
                      friday_start_hour TIME,
                      friday_end_hour TIME,
                      saturday_start_hour TIME,
                      saturday_end_hour TIME,
                      content TEXT NOT NULL,
                      embedding vector(768) NOT NULL
                    )
                    """
                )
            )
            # Insert all the data (skipped when there is nothing to insert)
            if amenities:
                await conn.execute(
                    text(
                        """
                        INSERT INTO amenities VALUES (:id, :name, :description, :location,
                          :terminal, :category, :hour, :sunday_start_hour, :sunday_end_hour,
                          :monday_start_hour, :monday_end_hour, :tuesday_start_hour,
                          :tuesday_end_hour, :wednesday_start_hour, :wednesday_end_hour,
                          :thursday_start_hour, :thursday_end_hour, :friday_start_hour,
                          :friday_end_hour, :saturday_start_hour, :saturday_end_hour, :content, :embedding)
                        """
                    ),
                    [
                        {
                            "id": a.id,
                            "name": a.name,
                            "description": a.description,
                            "location": a.location,
                            "terminal": a.terminal,
                            "category": a.category,
                            "hour": a.hour,
                            "sunday_start_hour": a.sunday_start_hour,
                            "sunday_end_hour": a.sunday_end_hour,
                            "monday_start_hour": a.monday_start_hour,
                            "monday_end_hour": a.monday_end_hour,
                            "tuesday_start_hour": a.tuesday_start_hour,
                            "tuesday_end_hour": a.tuesday_end_hour,
                            "wednesday_start_hour": a.wednesday_start_hour,
                            "wednesday_end_hour": a.wednesday_end_hour,
                            "thursday_start_hour": a.thursday_start_hour,
                            "thursday_end_hour": a.thursday_end_hour,
                            "friday_start_hour": a.friday_start_hour,
                            "friday_end_hour": a.friday_end_hour,
                            "saturday_start_hour": a.saturday_start_hour,
                            "saturday_end_hour": a.saturday_end_hour,
                            "content": a.content,
                            "embedding": a.embedding,
                        }
                        for a in amenities
                    ],
                )

            # If the table already exists, drop it to avoid conflicts
            await conn.execute(text("DROP TABLE IF EXISTS flights CASCADE"))
            # Create a new table
            await conn.execute(
                text(
                    """
                    CREATE TABLE flights(
                      id INTEGER PRIMARY KEY,
                      airline TEXT,
                      flight_number TEXT,
                      departure_airport TEXT,
                      arrival_airport TEXT,
                      departure_time TIMESTAMP,
                      arrival_time TIMESTAMP,
                      departure_gate TEXT,
                      arrival_gate TEXT
                    )
                    """
                )
            )
            # Insert all the data (skipped when there is nothing to insert)
            if flights:
                await conn.execute(
                    text(
                        """
                        INSERT INTO flights VALUES (:id, :airline, :flight_number,
                          :departure_airport, :arrival_airport, :departure_time,
                          :arrival_time, :departure_gate, :arrival_gate)
                        """
                    ),
                    [
                        {
                            "id": f.id,
                            "airline": f.airline,
                            "flight_number": f.flight_number,
                            "departure_airport": f.departure_airport,
                            "arrival_airport": f.arrival_airport,
                            "departure_time": f.departure_time,
                            "arrival_time": f.arrival_time,
                            "departure_gate": f.departure_gate,
                            "arrival_gate": f.arrival_gate,
                        }
                        for f in flights
                    ],
                )
            await conn.commit()

    async def export_data(
        self,
    ) -> tuple[list[models.Airport], list[models.Amenity], list[models.Flight]]:
        """Return every row of the airports, amenities and flights tables."""
        async with self.__pool.connect() as conn:
            # A single connection cannot run statements concurrently, so the
            # three SELECTs are issued one after another.
            airport_results = (
                (await conn.execute(text("""SELECT * FROM airports""")))
                .mappings()
                .fetchall()
            )
            amenity_results = (
                (await conn.execute(text("""SELECT * FROM amenities""")))
                .mappings()
                .fetchall()
            )
            flights_results = (
                (await conn.execute(text("""SELECT * FROM flights""")))
                .mappings()
                .fetchall()
            )

        airports = [models.Airport.model_validate(a) for a in airport_results]
        amenities = [models.Amenity.model_validate(a) for a in amenity_results]
        flights = [models.Flight.model_validate(f) for f in flights_results]
        return airports, amenities, flights

    async def get_airport_by_id(self, id: int) -> Optional[models.Airport]:
        """Return the airport with primary key ``id``, or None if absent."""
        async with self.__pool.connect() as conn:
            s = text("""SELECT * FROM airports WHERE id=:id""")
            params = {"id": id}
            result = (await conn.execute(s, params)).mappings().fetchone()

        if result is None:
            return None
        return models.Airport.model_validate(result)

    async def get_airport_by_iata(self, iata: str) -> Optional[models.Airport]:
        """Return the first airport whose IATA code matches ``iata``
        case-insensitively (ILIKE pattern semantics), or None."""
        async with self.__pool.connect() as conn:
            s = text("""SELECT * FROM airports WHERE iata ILIKE :iata""")
            params = {"iata": iata}
            result = (await conn.execute(s, params)).mappings().fetchone()

        if result is None:
            return None
        return models.Airport.model_validate(result)

    async def search_airports(
        self,
        country: Optional[str] = None,
        city: Optional[str] = None,
        name: Optional[str] = None,
    ) -> list[models.Airport]:
        """Search airports; each filter left as None is ignored.

        ``country`` and ``city`` are matched case-insensitively against the
        whole column; ``name`` is a case-insensitive substring match.
        """
        async with self.__pool.connect() as conn:
            s = text(
                """
                SELECT * FROM airports
                  WHERE (CAST(:country AS TEXT) IS NULL OR country ILIKE :country)
                  AND (CAST(:city AS TEXT) IS NULL OR city ILIKE :city)
                  AND (CAST(:name AS TEXT) IS NULL OR name ILIKE '%' || :name || '%')
                """
            )
            params = {
                "country": country,
                "city": city,
                "name": name,
            }
            results = (await conn.execute(s, params)).mappings().fetchall()

        return [models.Airport.model_validate(r) for r in results]

    async def get_amenity(self, id: int) -> Optional[models.Amenity]:
        """Return the amenity with primary key ``id`` (descriptive columns
        only, no hours/content/embedding), or None if absent."""
        async with self.__pool.connect() as conn:
            s = text(
                """
                SELECT id, name, description, location, terminal, category, hour
                FROM amenities WHERE id=:id
                """
            )
            params = {"id": id}
            result = (await conn.execute(s, params)).mappings().fetchone()

        if result is None:
            return None
        return models.Amenity.model_validate(result)

    async def amenities_search(
        self, query_embedding: list[float], similarity_threshold: float, top_k: int
    ) -> list[models.Amenity]:
        """Return up to ``top_k`` amenities most similar to ``query_embedding``.

        Similarity is ``1 - cosine distance`` (pgvector ``<=>``); only rows
        scoring above ``similarity_threshold`` are kept. Results are ordered
        from most to least similar.
        """
        async with self.__pool.connect() as conn:
            # The outer ORDER BY is required: ordering inside a subquery is
            # not guaranteed to survive into the outer result.
            s = text(
                """
                SELECT id, name, description, location, terminal, category, hour
                FROM (
                    SELECT id, name, description, location, terminal, category, hour,
                        1 - (embedding <=> :query_embedding) AS similarity
                    FROM amenities
                    WHERE 1 - (embedding <=> :query_embedding) > :similarity_threshold
                    ORDER BY similarity DESC
                    LIMIT :top_k
                ) AS sorted_amenities
                ORDER BY similarity DESC
                """
            )
            params = {
                "query_embedding": query_embedding,
                "similarity_threshold": similarity_threshold,
                "top_k": top_k,
            }
            results = (await conn.execute(s, params)).mappings().fetchall()

        return [models.Amenity.model_validate(r) for r in results]

    async def get_flight(self, flight_id: int) -> Optional[models.Flight]:
        """Return the flight with primary key ``flight_id``, or None if absent."""
        async with self.__pool.connect() as conn:
            s = text(
                """
                SELECT * FROM flights
                  WHERE id = :flight_id
                """
            )
            params = {"flight_id": flight_id}
            result = (await conn.execute(s, params)).mappings().fetchone()

        if result is None:
            return None
        return models.Flight.model_validate(result)

    async def search_flights_by_number(
        self,
        airline: str,
        number: str,
    ) -> list[models.Flight]:
        """Return flights whose airline and flight number match exactly."""
        async with self.__pool.connect() as conn:
            s = text(
                """
                SELECT * FROM flights
                  WHERE airline = :airline
                  AND flight_number = :number
                """
            )
            params = {
                "airline": airline,
                "number": number,
            }
            results = (await conn.execute(s, params)).mappings().fetchall()

        return [models.Flight.model_validate(r) for r in results]

    async def search_flights_by_airports(
        self,
        date: str,
        departure_airport: Optional[str] = None,
        arrival_airport: Optional[str] = None,
    ) -> list[models.Flight]:
        """Return flights departing on ``date`` (``YYYY-MM-DD``).

        Optional airport filters are matched case-insensitively; a None filter
        is ignored.

        Raises:
            ValueError: if ``date`` is not in ``YYYY-MM-DD`` format.
        """
        async with self.__pool.connect() as conn:
            s = text(
                """
                SELECT * FROM flights
                  WHERE (CAST(:departure_airport AS TEXT) IS NULL OR departure_airport ILIKE :departure_airport)
                  AND (CAST(:arrival_airport AS TEXT) IS NULL OR arrival_airport ILIKE :arrival_airport)
                  AND departure_time >= CAST(:datetime AS timestamp)
                  AND departure_time < CAST(:datetime AS timestamp) + interval '1 day'
                """
            )
            params = {
                "departure_airport": departure_airport,
                "arrival_airport": arrival_airport,
                "datetime": datetime.strptime(date, "%Y-%m-%d"),
            }
            results = (await conn.execute(s, params)).mappings().fetchall()

        return [models.Flight.model_validate(r) for r in results]

    async def insert_ticket(
        self,
        user_id: str,
        user_name: str,
        user_email: str,
        airline: str,
        flight_number: str,
        departure_airport: str,
        arrival_airport: str,
        departure_time: str,
        arrival_time: str,
    ):
        """Not supported by this provider.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Not Implemented")

    async def list_tickets(
        self,
        user_id: str,
    ) -> list[models.Ticket]:
        """Not supported by this provider.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Not Implemented")

    async def close(self):
        """Dispose the connection pool, then release the Cloud SQL connector."""
        try:
            await self.__pool.dispose()
        finally:
            # Close the shared connector even if disposing the pool failed.
            if self.__connector is not None:
                await self.__connector.close_async()