# retrieval_service/datastore/providers/cloudsql_mysql.py (535 lines of code) (raw):
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from datetime import datetime
from typing import Any, Literal, Optional
import pymysql
from google.cloud.sql.connector import Connector, RefreshStrategy
from pydantic import BaseModel
from sqlalchemy import Engine, create_engine, text
from sqlalchemy.engine.base import Engine
import models
from .. import datastore
MYSQL_IDENTIFIER = "cloudsql-mysql"
# Connection settings for the Cloud SQL for MySQL datastore provider.
class Config(BaseModel, datastore.AbstractConfig):
# Must be the literal provider identifier "cloudsql-mysql".
kind: Literal["cloudsql-mysql"]
# project, region and instance are joined as "project:region:instance" to
# form the Cloud SQL instance connection name when a connection is opened.
project: str
region: str
instance: str
# Credentials and database name handed to the connector for each connection.
user: str
password: str
database: str
# Datastore client that runs its queries against Cloud SQL for MySQL through
# a SQLAlchemy engine.
class Client(datastore.Client[Config]):
# SQLAlchemy engine from which every query method obtains its connection.
__pool: Engine
# Database name this client was constructed with.
__db_name: str
# Cloud SQL connector stored on the class rather than on instances; it stays
# None until the first connection is requested.
__connector: Optional[Connector] = None
@datastore.classproperty
def kind(cls):
    """Return the identifier string of this datastore provider."""
    return MYSQL_IDENTIFIER
def __init__(self, pool: Engine, db_name: str):
    """Keep the engine and database name for use by the query methods.

    Args:
        pool: SQLAlchemy engine that supplies connections.
        db_name: Name of the database the engine is connected to.
    """
    self.__db_name = db_name
    self.__pool = pool
@classmethod
def create_sync(cls, config: Config) -> "Client":
    """Build a client whose engine connects through the Cloud SQL connector.

    Args:
        config: Instance coordinates, credentials and database name.

    Returns:
        A client wrapping the newly created SQLAlchemy engine.

    Raises:
        TypeError: If no engine object was produced.
    """

    def getconn() -> pymysql.Connection:
        # The connector lives on the class and is only built the first
        # time a connection is actually needed.
        if cls.__connector is None:
            cls.__connector = Connector(refresh_strategy=RefreshStrategy.LAZY)
        # Cloud SQL instance connection name
        instance_connection_name = (
            f"{config.project}:{config.region}:{config.instance}"
        )
        return cls.__connector.connect(
            instance_connection_name,
            "pymysql",
            user=f"{config.user}",
            password=f"{config.password}",
            db=f"{config.database}",
            autocommit=True,
        )

    # The URL only selects the dialect/driver; connections come from getconn.
    engine = create_engine("mysql+pymysql://", creator=getconn)
    if engine is None:
        raise TypeError("pool not instantiated")
    return cls(engine, config.database)
@classmethod
async def create(cls, config: Config) -> "Client":
    """Create a client from ``config`` without blocking the event loop.

    The synchronous constructor runs on the loop's default executor.
    """
    client = await asyncio.get_running_loop().run_in_executor(
        None, cls.create_sync, config
    )
    return client
def initialize_data_sync(
    self,
    airports: list[models.Airport],
    amenities: list[models.Amenity],
    flights: list[models.Flight],
    policies: list[models.Policy],
) -> None:
    """Recreate the application tables and load the given records.

    The airports, amenities, flights, tickets and policies tables are
    dropped if they exist and created again, so any previous rows are lost.
    Every table except tickets is then filled from the matching argument;
    an empty list leaves its table empty.

    Args:
        airports: Rows for the airports table.
        amenities: Rows for the amenities table; each embedding is sent as
            text and converted with string_to_vector.
        flights: Rows for the flights table.
        policies: Rows for the policies table; embeddings are converted the
            same way as for amenities.
    """

    def insert_rows(conn: Any, statement: str, rows: list[dict[str, Any]]) -> None:
        # An empty parameter list makes SQLAlchemy execute the INSERT once
        # with no bound values, which fails on the missing bind
        # parameters, so an empty dataset is simply skipped.
        if rows:
            conn.execute(text(statement), parameters=rows)

    with self.__pool.connect() as conn:
        # If the table already exists, drop it to avoid conflicts
        conn.execute(text("DROP TABLE IF EXISTS airports"))
        # Create a new table
        conn.execute(
            text(
                """
                CREATE TABLE airports(
                  id INT PRIMARY KEY,
                  iata TEXT,
                  name TEXT,
                  city TEXT,
                  country TEXT
                )
                """
            )
        )
        # Insert all the data
        insert_rows(
            conn,
            """INSERT INTO airports VALUES (:id, :iata, :name, :city, :country)""",
            [
                {
                    "id": a.id,
                    "iata": a.iata,
                    "name": a.name,
                    "city": a.city,
                    "country": a.country,
                }
                for a in airports
            ],
        )

        # If the table already exists, drop it to avoid conflicts
        conn.execute(text("DROP TABLE IF EXISTS amenities CASCADE"))
        # Create a new table
        conn.execute(
            text(
                """
                CREATE TABLE amenities(
                  id INT PRIMARY KEY,
                  name TEXT,
                  description TEXT,
                  location TEXT,
                  terminal TEXT,
                  category TEXT,
                  hour TEXT,
                  sunday_start_hour TIME,
                  sunday_end_hour TIME,
                  monday_start_hour TIME,
                  monday_end_hour TIME,
                  tuesday_start_hour TIME,
                  tuesday_end_hour TIME,
                  wednesday_start_hour TIME,
                  wednesday_end_hour TIME,
                  thursday_start_hour TIME,
                  thursday_end_hour TIME,
                  friday_start_hour TIME,
                  friday_end_hour TIME,
                  saturday_start_hour TIME,
                  saturday_end_hour TIME,
                  content TEXT NOT NULL,
                  embedding vector(768) USING VARBINARY NOT NULL
                )
                """
            )
        )
        # Insert all the data
        insert_rows(
            conn,
            """
            INSERT INTO amenities VALUES (:id, :name, :description, :location,
              :terminal, :category, :hour, :sunday_start_hour, :sunday_end_hour,
              :monday_start_hour, :monday_end_hour, :tuesday_start_hour,
              :tuesday_end_hour, :wednesday_start_hour, :wednesday_end_hour,
              :thursday_start_hour, :thursday_end_hour, :friday_start_hour,
              :friday_end_hour, :saturday_start_hour, :saturday_end_hour, :content, string_to_vector(:embedding))
            """,
            [
                {
                    "id": a.id,
                    "name": a.name,
                    "description": a.description,
                    "location": a.location,
                    "terminal": a.terminal,
                    "category": a.category,
                    "hour": a.hour,
                    "sunday_start_hour": a.sunday_start_hour,
                    "sunday_end_hour": a.sunday_end_hour,
                    "monday_start_hour": a.monday_start_hour,
                    "monday_end_hour": a.monday_end_hour,
                    "tuesday_start_hour": a.tuesday_start_hour,
                    "tuesday_end_hour": a.tuesday_end_hour,
                    "wednesday_start_hour": a.wednesday_start_hour,
                    "wednesday_end_hour": a.wednesday_end_hour,
                    "thursday_start_hour": a.thursday_start_hour,
                    "thursday_end_hour": a.thursday_end_hour,
                    "friday_start_hour": a.friday_start_hour,
                    "friday_end_hour": a.friday_end_hour,
                    "saturday_start_hour": a.saturday_start_hour,
                    "saturday_end_hour": a.saturday_end_hour,
                    "content": a.content,
                    # The vector is sent as its text form and converted by
                    # string_to_vector on the server.
                    "embedding": f"{a.embedding}",
                }
                for a in amenities
            ],
        )

        # If the table already exists, drop it to avoid conflicts
        conn.execute(text("DROP TABLE IF EXISTS flights"))
        # Create a new table
        conn.execute(
            text(
                """
                CREATE TABLE flights(
                  id INTEGER PRIMARY KEY,
                  airline TEXT,
                  flight_number TEXT,
                  departure_airport TEXT,
                  arrival_airport TEXT,
                  departure_time TIMESTAMP,
                  arrival_time TIMESTAMP,
                  departure_gate TEXT,
                  arrival_gate TEXT
                )
                """
            )
        )
        # Insert all the data
        insert_rows(
            conn,
            """
            INSERT INTO flights VALUES (:id, :airline, :flight_number,
              :departure_airport, :arrival_airport, :departure_time,
              :arrival_time, :departure_gate, :arrival_gate)
            """,
            [
                {
                    "id": f.id,
                    "airline": f.airline,
                    "flight_number": f.flight_number,
                    "departure_airport": f.departure_airport,
                    "arrival_airport": f.arrival_airport,
                    "departure_time": f.departure_time,
                    "arrival_time": f.arrival_time,
                    "departure_gate": f.departure_gate,
                    "arrival_gate": f.arrival_gate,
                }
                for f in flights
            ],
        )

        # If the table already exists, drop it to avoid conflicts
        conn.execute(text("DROP TABLE IF EXISTS tickets"))
        # Create a new table (tickets start empty; nothing is loaded here)
        conn.execute(
            text(
                """
                CREATE TABLE tickets(
                  user_id TEXT,
                  user_name TEXT,
                  user_email TEXT,
                  airline TEXT,
                  flight_number TEXT,
                  departure_airport TEXT,
                  arrival_airport TEXT,
                  departure_time TIMESTAMP,
                  arrival_time TIMESTAMP
                )
                """
            )
        )

        # If the table already exists, drop it to avoid conflicts
        conn.execute(text("DROP TABLE IF EXISTS policies"))
        # Create a new table
        conn.execute(
            text(
                """
                CREATE TABLE policies(
                  id INT PRIMARY KEY,
                  content TEXT NOT NULL,
                  embedding vector(768) USING VARBINARY NOT NULL
                )
                """
            )
        )
        # Insert all the data
        insert_rows(
            conn,
            """
            INSERT INTO policies VALUES (:id, :content, string_to_vector(:embedding))
            """,
            [
                {
                    "id": p.id,
                    "content": p.content,
                    "embedding": f"{p.embedding}",
                }
                for p in policies
            ],
        )
async def initialize_data(
    self,
    airports: list[models.Airport],
    amenities: list[models.Amenity],
    flights: list[models.Flight],
    policies: list[models.Policy],
) -> None:
    """Load the given datasets by running initialize_data_sync on the
    event loop's default executor."""
    pending = asyncio.get_running_loop().run_in_executor(
        None, self.initialize_data_sync, airports, amenities, flights, policies
    )
    await pending
def export_data_sync(
    self,
) -> tuple[
    list[models.Airport],
    list[models.Amenity],
    list[models.Flight],
    list[models.Policy],
]:
    """Read back every airport, amenity, flight and policy row.

    Each table is read in ascending id order. Amenity opening hours are
    formatted as HH:MM strings and embeddings are returned in their text
    form via vector_to_string.

    Returns:
        A tuple of (airports, amenities, flights, policies) model lists.
    """
    airports_sql = text("""SELECT * FROM airports ORDER BY id ASC""")
    amenities_sql = text(
        """
        SELECT id,
          name,
          description,
          location,
          terminal,
          category,
          hour,
          DATE_FORMAT(sunday_start_hour, '%H:%i') AS sunday_start_hour,
          DATE_FORMAT(sunday_end_hour, '%H:%i') AS sunday_end_hour,
          DATE_FORMAT(monday_start_hour, '%H:%i') AS monday_start_hour,
          DATE_FORMAT(monday_end_hour, '%H:%i') AS monday_end_hour,
          DATE_FORMAT(tuesday_start_hour, '%H:%i') AS tuesday_start_hour,
          DATE_FORMAT(tuesday_end_hour, '%H:%i') AS tuesday_end_hour,
          DATE_FORMAT(wednesday_start_hour, '%H:%i') AS wednesday_start_hour,
          DATE_FORMAT(wednesday_end_hour, '%H:%i') AS wednesday_end_hour,
          DATE_FORMAT(thursday_start_hour, '%H:%i') AS thursday_start_hour,
          DATE_FORMAT(thursday_end_hour, '%H:%i') AS thursday_end_hour,
          DATE_FORMAT(friday_start_hour, '%H:%i') AS friday_start_hour,
          DATE_FORMAT(friday_end_hour, '%H:%i') AS friday_end_hour,
          DATE_FORMAT(saturday_start_hour, '%H:%i') AS saturday_start_hour,
          DATE_FORMAT(saturday_end_hour, '%H:%i') AS saturday_end_hour,
          content,
          vector_to_string(embedding) as embedding
        FROM amenities ORDER BY id ASC
        """
    )
    flights_sql = text("""SELECT * FROM flights ORDER BY id ASC""")
    policies_sql = text(
        """SELECT id, content, vector_to_string(embedding) as embedding FROM policies ORDER BY id ASC"""
    )

    with self.__pool.connect() as conn:
        # All four statements are issued before any rows are read.
        airport_cursor = conn.execute(airports_sql)
        amenity_cursor = conn.execute(amenities_sql)
        flight_cursor = conn.execute(flights_sql)
        policy_cursor = conn.execute(policies_sql)

        airport_rows = airport_cursor.mappings().fetchall()
        amenity_rows = amenity_cursor.mappings().fetchall()
        flight_rows = flight_cursor.mappings().fetchall()
        policy_rows = policy_cursor.mappings().fetchall()

    return (
        [models.Airport.model_validate(row) for row in airport_rows],
        [models.Amenity.model_validate(row) for row in amenity_rows],
        [models.Flight.model_validate(row) for row in flight_rows],
        [models.Policy.model_validate(row) for row in policy_rows],
    )
async def export_data(
    self,
) -> tuple[
    list[models.Airport],
    list[models.Amenity],
    list[models.Flight],
    list[models.Policy],
]:
    """Export all datasets by running export_data_sync on the event loop's
    default executor and returning its result unchanged."""
    exported = await asyncio.get_running_loop().run_in_executor(
        None, self.export_data_sync
    )
    return exported
def get_airport_by_id_sync(
    self, id: int
) -> tuple[Optional[models.Airport], Optional[str]]:
    """Look up one airport by its primary key.

    Returns:
        (airport, None) when a row exists, otherwise (None, None). The
        second element is always None.
    """
    query = text("""SELECT * FROM airports WHERE id=:id""")
    with self.__pool.connect() as conn:
        row = conn.execute(query, {"id": id}).mappings().fetchone()
    if row is None:
        return None, None
    return models.Airport.model_validate(row), None
async def get_airport_by_id(
    self, id: int
) -> tuple[Optional[models.Airport], Optional[str]]:
    """Async wrapper: run get_airport_by_id_sync on the default executor."""
    pending = asyncio.get_running_loop().run_in_executor(
        None, self.get_airport_by_id_sync, id
    )
    airport, sql = await pending
    return airport, sql
def get_airport_by_iata_sync(
    self, iata: str
) -> tuple[Optional[models.Airport], Optional[str]]:
    """Fetch the first airport whose IATA code matches ``iata``.

    Both sides are lower-cased and compared with SQL LIKE.

    Returns:
        (airport, None) when a row matches, otherwise (None, None).
    """
    query = text("""SELECT * FROM airports WHERE LOWER(iata) LIKE LOWER(:iata)""")
    with self.__pool.connect() as conn:
        row = conn.execute(query, {"iata": iata}).mappings().fetchone()
    if row is None:
        return None, None
    return models.Airport.model_validate(row), None
async def get_airport_by_iata(
    self, iata: str
) -> tuple[Optional[models.Airport], Optional[str]]:
    """Async wrapper: run get_airport_by_iata_sync on the default executor."""
    pending = asyncio.get_running_loop().run_in_executor(
        None, self.get_airport_by_iata_sync, iata
    )
    airport, sql = await pending
    return airport, sql
def search_airports_sync(
    self,
    country: Optional[str] = None,
    city: Optional[str] = None,
    name: Optional[str] = None,
) -> tuple[list[models.Airport], Optional[str]]:
    """Search airports by case-insensitive substring filters.

    A filter left as None is not applied. At most 10 rows are returned.

    Returns:
        (matching airports, None).
    """
    query = text(
        """
        SELECT * FROM airports
          WHERE (:country IS NULL OR LOWER(country) LIKE CONCAT('%', LOWER(:country), '%'))
          AND (:city IS NULL OR LOWER(city) LIKE CONCAT('%', LOWER(:city), '%'))
          AND (:name IS NULL OR LOWER(name) LIKE CONCAT('%', LOWER(:name), '%'))
          LIMIT 10;
        """
    )
    filters = {"country": country, "city": city, "name": name}
    with self.__pool.connect() as conn:
        rows = conn.execute(query, parameters=filters).mappings().fetchall()
    matches = [models.Airport.model_validate(row) for row in rows]
    return matches, None
async def search_airports(
    self,
    country: Optional[str] = None,
    city: Optional[str] = None,
    name: Optional[str] = None,
) -> tuple[list[models.Airport], Optional[str]]:
    """Async wrapper: run search_airports_sync on the default executor."""
    pending = asyncio.get_running_loop().run_in_executor(
        None, self.search_airports_sync, country, city, name
    )
    airports, sql = await pending
    return airports, sql
def get_amenity_sync(
    self, id: int
) -> tuple[Optional[models.Amenity], Optional[str]]:
    """Look up one amenity by id.

    Only the id, name, description, location, terminal, category and hour
    columns are selected.

    Returns:
        (amenity, None) when a row exists, otherwise (None, None).
    """
    query = text(
        """
        SELECT id, name, description, location, terminal, category, hour
        FROM amenities WHERE id=:id
        """
    )
    with self.__pool.connect() as conn:
        row = conn.execute(query, parameters={"id": id}).mappings().fetchone()
    if row is None:
        return None, None
    return models.Amenity.model_validate(row), None
async def get_amenity(
    self, id: int
) -> tuple[Optional[models.Amenity], Optional[str]]:
    """Async wrapper: run get_amenity_sync on the default executor."""
    pending = asyncio.get_running_loop().run_in_executor(
        None, self.get_amenity_sync, id
    )
    amenity, sql = await pending
    return amenity, sql
def amenities_search_sync(
    self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> tuple[list[Any], Optional[str]]:
    """Return the ``top_k`` amenities ordered by approximate cosine distance
    to ``query_embedding``.

    ``similarity_threshold`` is accepted but not used by this query.

    Returns:
        (row mappings with name, description, location, terminal, category
        and hour, None).
    """
    query = text(
        """
        SELECT name, description, location, terminal, category, hour
        FROM amenities
        ORDER BY APPROX_DISTANCE(embedding, string_to_vector(:query), 'distance_measure=cosine') LIMIT :search_options
        """
    )
    # The embedding travels as its text form and is converted server-side.
    bindings = {"query": f"{query_embedding}", "search_options": top_k}
    with self.__pool.connect() as conn:
        rows = conn.execute(query, parameters=bindings).mappings().fetchall()
    return list(rows), None
async def amenities_search(
    self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> tuple[list[Any], Optional[str]]:
    """Async wrapper: run amenities_search_sync on the default executor."""
    pending = asyncio.get_running_loop().run_in_executor(
        None,
        self.amenities_search_sync,
        query_embedding,
        similarity_threshold,
        top_k,
    )
    amenities, sql = await pending
    return amenities, sql
def get_flight_sync(
    self, flight_id: int
) -> tuple[Optional[models.Flight], Optional[str]]:
    """Look up one flight by id.

    Returns:
        (flight, None) when a row exists, otherwise (None, None).
    """
    query = text(
        """
        SELECT * FROM flights
          WHERE id = :flight_id
        """
    )
    with self.__pool.connect() as conn:
        row = (
            conn.execute(query, parameters={"flight_id": flight_id})
            .mappings()
            .fetchone()
        )
    if row is None:
        return None, None
    return models.Flight.model_validate(row), None
async def get_flight(
    self, flight_id: int
) -> tuple[Optional[models.Flight], Optional[str]]:
    """Async wrapper: run get_flight_sync on the default executor."""
    pending = asyncio.get_running_loop().run_in_executor(
        None, self.get_flight_sync, flight_id
    )
    flight, sql = await pending
    return flight, sql
def search_flights_by_number_sync(
    self,
    airline: str,
    number: str,
) -> tuple[list[models.Flight], Optional[str]]:
    """Find flights with exactly this airline and flight number.

    At most 10 rows are returned.

    Returns:
        (matching flights, None).
    """
    query = text(
        """
        SELECT * FROM flights
          WHERE airline = :airline
          AND flight_number = :number
          LIMIT 10
        """
    )
    bindings = {"airline": airline, "number": number}
    with self.__pool.connect() as conn:
        rows = conn.execute(query, parameters=bindings).mappings().fetchall()
    flights = [models.Flight.model_validate(row) for row in rows]
    return flights, None
async def search_flights_by_number(
    self,
    airline: str,
    number: str,
) -> tuple[list[models.Flight], Optional[str]]:
    """Async wrapper: run search_flights_by_number_sync on the default
    executor."""
    pending = asyncio.get_running_loop().run_in_executor(
        None, self.search_flights_by_number_sync, airline, number
    )
    flights, sql = await pending
    return flights, sql
def search_flights_by_airports_sync(
    self,
    date: str,
    departure_airport: Optional[str] = None,
    arrival_airport: Optional[str] = None,
) -> tuple[list[models.Flight], Optional[str]]:
    """Find flights departing on ``date``, optionally filtered by airports.

    Args:
        date: Day to search, formatted as YYYY-MM-DD.
        departure_airport: Case-insensitive LIKE filter; None disables it.
        arrival_airport: Case-insensitive LIKE filter; None disables it.

    Returns:
        (up to 10 matching flights, None).

    Raises:
        ValueError: If ``date`` does not match the YYYY-MM-DD format.
    """
    with self.__pool.connect() as conn:
        query = text(
            """
            SELECT * FROM flights
              WHERE (CAST(:departure_airport AS CHAR(255)) IS NULL OR LOWER(departure_airport) LIKE LOWER(:departure_airport))
              AND (CAST(:arrival_airport AS CHAR(255)) IS NULL OR LOWER(arrival_airport) LIKE LOWER(:arrival_airport))
              AND departure_time >= CAST(:datetime AS DATETIME)
              AND (departure_time < DATE_ADD(CAST(:datetime AS DATETIME), interval 1 day))
              LIMIT 10
            """
        )
        # The query covers the 24 hours starting at midnight of that day.
        day_start = datetime.strptime(date, "%Y-%m-%d")
        bindings = {
            "departure_airport": departure_airport,
            "arrival_airport": arrival_airport,
            "datetime": day_start,
        }
        rows = conn.execute(query, parameters=bindings).mappings().fetchall()
    flights = [models.Flight.model_validate(row) for row in rows]
    return flights, None
async def search_flights_by_airports(
    self,
    date: str,
    departure_airport: Optional[str] = None,
    arrival_airport: Optional[str] = None,
) -> tuple[list[models.Flight], Optional[str]]:
    """Async wrapper: run search_flights_by_airports_sync on the default
    executor."""
    pending = asyncio.get_running_loop().run_in_executor(
        None,
        self.search_flights_by_airports_sync,
        date,
        departure_airport,
        arrival_airport,
    )
    flights, sql = await pending
    return flights, sql
def validate_ticket_sync(
    self,
    airline: str,
    flight_number: str,
    departure_airport: str,
    departure_time: str,
) -> tuple[Optional[models.Flight], Optional[str]]:
    """Find the flight that a ticket request refers to.

    Airline, flight number and departure airport are matched with a
    case-insensitive LIKE; the departure time must be equal once the
    database has cast ``departure_time`` to DATETIME.

    Returns:
        (first matching flight, None), or (None, None) if none matches.
    """
    query = text(
        """
        SELECT * FROM flights
          WHERE LOWER(airline) LIKE LOWER(:airline)
          AND LOWER(flight_number) LIKE LOWER(:flight_number)
          AND LOWER(departure_airport) LIKE LOWER(:departure_airport)
          AND departure_time = CAST(:departure_time AS DATETIME)
          LIMIT 10
        """
    )
    bindings = dict(
        airline=airline,
        flight_number=flight_number,
        departure_airport=departure_airport,
        departure_time=departure_time,
    )
    with self.__pool.connect() as conn:
        row = conn.execute(query, parameters=bindings).mappings().fetchone()
    if row is None:
        return None, None
    return models.Flight.model_validate(row), None
async def validate_ticket(
    self,
    airline: str,
    flight_number: str,
    departure_airport: str,
    departure_time: str,
) -> tuple[Optional[models.Flight], Optional[str]]:
    """Async wrapper: run validate_ticket_sync on the default executor."""
    pending = asyncio.get_running_loop().run_in_executor(
        None,
        self.validate_ticket_sync,
        airline,
        flight_number,
        departure_airport,
        departure_time,
    )
    flight, sql = await pending
    return flight, sql
def insert_ticket_sync(
    self,
    user_id: str,
    user_name: str,
    user_email: str,
    airline: str,
    flight_number: str,
    departure_airport: str,
    arrival_airport: str,
    departure_time: str,
    arrival_time: str,
):
    """Insert one row into the tickets table from the given values.

    Nothing is returned.
    """
    statement = text(
        """
        INSERT INTO tickets (
            user_id,
            user_name,
            user_email,
            airline,
            flight_number,
            departure_airport,
            arrival_airport,
            departure_time,
            arrival_time
        ) VALUES (
            :user_id,
            :user_name,
            :user_email,
            :airline,
            :flight_number,
            :departure_airport,
            :arrival_airport,
            :departure_time,
            :arrival_time
        );
        """
    )
    ticket = dict(
        user_id=user_id,
        user_name=user_name,
        user_email=user_email,
        airline=airline,
        flight_number=flight_number,
        departure_airport=departure_airport,
        arrival_airport=arrival_airport,
        departure_time=departure_time,
        arrival_time=arrival_time,
    )
    with self.__pool.connect() as conn:
        outcome = conn.execute(statement, ticket)
        outcome.mappings()
async def insert_ticket(
    self,
    user_id: str,
    user_name: str,
    user_email: str,
    airline: str,
    flight_number: str,
    departure_airport: str,
    arrival_airport: str,
    departure_time: str,
    arrival_time: str,
):
    """Async wrapper: run insert_ticket_sync on the default executor."""
    ticket_fields = (
        user_id,
        user_name,
        user_email,
        airline,
        flight_number,
        departure_airport,
        arrival_airport,
        departure_time,
        arrival_time,
    )
    await asyncio.get_running_loop().run_in_executor(
        None, self.insert_ticket_sync, *ticket_fields
    )
def list_tickets_sync(
    self,
    user_id: str,
) -> tuple[list[Any], Optional[str]]:
    """List the tickets stored for ``user_id``.

    Returns:
        (row mappings with user_name, airline, flight_number,
        departure_airport, arrival_airport, departure_time and
        arrival_time, None).
    """
    query = text(
        """
        SELECT user_name, airline, flight_number, departure_airport, arrival_airport, departure_time, arrival_time FROM tickets
          WHERE user_id = :user_id
        """
    )
    with self.__pool.connect() as conn:
        rows = (
            conn.execute(query, parameters={"user_id": user_id})
            .mappings()
            .fetchall()
        )
    return list(rows), None
async def list_tickets(
    self,
    user_id: str,
) -> tuple[list[models.Ticket], Optional[str]]:
    """Async wrapper: run list_tickets_sync on the default executor."""
    pending = asyncio.get_running_loop().run_in_executor(
        None, self.list_tickets_sync, user_id
    )
    tickets, sql = await pending
    return tickets, sql
def policies_search_sync(
    self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> tuple[list[str], Optional[str]]:
    """Return the content of the ``top_k`` policies ordered by approximate
    cosine distance to ``query_embedding``.

    ``similarity_threshold`` is accepted but not used by this query.

    Returns:
        (policy content strings, None).
    """
    query = text(
        """
        SELECT content
        FROM policies
        ORDER BY APPROX_DISTANCE(embedding, string_to_vector(:query), 'distance_measure=cosine') LIMIT :search_options
        """
    )
    # The embedding travels as its text form and is converted server-side.
    bindings = {"query": f"{query_embedding}", "search_options": top_k}
    with self.__pool.connect() as conn:
        rows = conn.execute(query, parameters=bindings).mappings().fetchall()
    contents = [row["content"] for row in rows]
    return contents, None
async def policies_search(
    self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> tuple[list[str], Optional[str]]:
    """Async wrapper: run policies_search_sync on the default executor."""
    pending = asyncio.get_running_loop().run_in_executor(
        None,
        self.policies_search_sync,
        query_embedding,
        similarity_threshold,
        top_k,
    )
    policies, sql = await pending
    return policies, sql
async def close(self):
    """Dispose of the engine's connection pool.

    Disposing the pool is a blocking call, so it is run on the event
    loop's default executor like the other blocking operations of this
    client instead of directly on the loop.
    """
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, self.__pool.dispose)