retrieval_service/datastore/providers/spanner_postgres.py (565 lines of code) (raw):

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
from typing import Any, Literal, Optional

from google.cloud import spanner  # type: ignore
from google.cloud.spanner_v1 import JsonObject, param_types
from google.cloud.spanner_v1.database import Database
from google.cloud.spanner_v1.instance import Instance
from google.oauth2 import service_account  # type: ignore
from pydantic import BaseModel

import models

from .. import datastore

# Identifier for Spanner (PostgreSQL dialect)
SPANNER_IDENTIFIER = "spanner-postgres"


# Configuration model for Spanner
class Config(BaseModel, datastore.AbstractConfig):
    """
    Configuration model for Spanner (PostgreSQL dialect).

    Attributes:
        kind (Literal["spanner-postgres"]): Type of datastore.
        project (str): Google Cloud project ID.
        instance (str): ID of the Spanner instance.
        database (str): ID of the Spanner database.
        service_account_key_file (Optional[str]): Path to a service account
            key file. When omitted, the spanner.Client is created without
            explicit credentials.
    """

    kind: Literal["spanner-postgres"]
    project: str
    instance: str
    database: str
    service_account_key_file: Optional[str] = None


# Client class for interacting with Spanner
class Client(datastore.Client[Config]):
    """
    Client class for interacting with a PostgreSQL-dialect Spanner database.

    Attributes:
        __client (spanner.Client): Spanner client instance.
        __instance_id (str): ID of the Spanner instance.
        __database_id (str): ID of the Spanner database.
        __instance (Instance): Spanner instance.
        __database (Database): Spanner database.
    """

    # Timeout (seconds) passed to the schema-update operation in initialize_data.
    OPERATION_TIMEOUT_SECONDS = 240
    # Maximum number of rows written per mutation batch in initialize_data.
    BATCH_SIZE = 1000

    AIRPORT_COLUMNS = ["id", "iata", "name", "city", "country"]
    AMENITIES_COLUMNS = [
        "id",
        "name",
        "description",
        "location",
        "terminal",
        "category",
        "hour",
        "sunday_start_hour",
        "sunday_end_hour",
        "monday_start_hour",
        "monday_end_hour",
        "tuesday_start_hour",
        "tuesday_end_hour",
        "wednesday_start_hour",
        "wednesday_end_hour",
        "thursday_start_hour",
        "thursday_end_hour",
        "friday_start_hour",
        "friday_end_hour",
        "saturday_start_hour",
        "saturday_end_hour",
        "content",
        "embedding",
    ]
    FLIGHTS_COLUMNS = [
        "id",
        "airline",
        "flight_number",
        "departure_airport",
        "arrival_airport",
        "departure_time",
        "arrival_time",
        "departure_gate",
        "arrival_gate",
    ]
    POLICIES_COLUMNS = ["id", "content", "embedding"]
    TICKETS_COLUMNS = [
        "user_id",
        "user_name",
        "user_email",
        "airline",
        "flight_number",
        "departure_airport",
        "arrival_airport",
        "departure_time",
        "arrival_time",
    ]

    @datastore.classproperty
    def kind(cls):
        return SPANNER_IDENTIFIER

    def __init__(self, client: spanner.Client, instance_id: str, database_id: str):
        """
        Initialize the Spanner client.

        Args:
            client (spanner.Client): Spanner client instance.
            instance_id (str): ID of the Spanner instance.
            database_id (str): ID of the Spanner database.
        """
        self.__client = client
        self.__instance_id = instance_id
        self.__database_id = database_id
        self.__instance = self.__client.instance(self.__instance_id)
        self.__database = self.__instance.database(self.__database_id)

    @classmethod
    async def create(cls, config: Config) -> "Client":
        """
        Create a Spanner client.

        Args:
            config (Config): Configuration for creating the client.

        Returns:
            Client: Initialized Spanner client.

        Raises:
            Exception: If the configured instance or database does not exist.
        """
        client: spanner.Client
        if config.service_account_key_file is not None:
            credentials = service_account.Credentials.from_service_account_file(
                config.service_account_key_file
            )
            client = spanner.Client(project=config.project, credentials=credentials)
        else:
            client = spanner.Client(project=config.project)

        instance_id = config.instance
        instance = client.instance(instance_id)
        if not instance.exists():
            raise Exception(f"Instance with id: {instance_id} doesn't exist.")

        database_id = config.database
        database = instance.database(database_id)
        if not database.exists():
            raise Exception(f"Database with id: {database_id} doesn't exist.")

        return cls(client, instance_id, database_id)

    @staticmethod
    def _row_values(
        item: Any, columns: list[str], stringify: Optional[type] = None
    ) -> tuple:
        """
        Return the attributes of ``item`` named in ``columns`` as a tuple.

        Values that are instances of ``stringify`` (when given) are converted
        with ``str()`` so they can be written to VARCHAR columns.
        """
        values = []
        for field in columns:
            value = getattr(item, field)
            if stringify is not None and isinstance(value, stringify):
                value = str(value)
            values.append(value)
        return tuple(values)

    def _insert_in_batches(
        self, table: str, columns: list[str], values: list[tuple]
    ) -> None:
        """
        Insert ``values`` into ``table``, committing one mutation batch per
        ``BATCH_SIZE`` rows.
        """
        for start in range(0, len(values), self.BATCH_SIZE):
            with self.__database.batch() as batch:
                batch.insert(
                    table=table,
                    columns=columns,
                    values=values[start : start + self.BATCH_SIZE],
                )

    async def initialize_data(
        self,
        airports: list[models.Airport],
        amenities: list[models.Amenity],
        flights: list[models.Flight],
        policies: list[models.Policy],
    ) -> None:
        """
        Initialize data in the Spanner database by (re)creating tables and
        inserting records.

        Every table, including ``tickets``, is dropped and recreated, so any
        existing rows are discarded.

        Args:
            airports (list[models.Airport]): list of airports to be initialized.
            amenities (list[models.Amenity]): list of amenities to be initialized.
            flights (list[models.Flight]): list of flights to be initialized.
            policies (list[models.Policy]): list of policies to be initialized.

        Returns:
            None
        """
        # Data Definition Language (DDL) statements: drop then recreate each table.
        ddl = [
            "DROP TABLE IF EXISTS airports",
            """
            CREATE TABLE airports(
              id BIGINT PRIMARY KEY,
              iata VARCHAR,
              name VARCHAR,
              city VARCHAR,
              country VARCHAR
            )
            """,
            "DROP TABLE IF EXISTS amenities",
            """
            CREATE TABLE amenities(
              id BIGINT PRIMARY KEY,
              name VARCHAR,
              description VARCHAR,
              location VARCHAR,
              terminal VARCHAR,
              category VARCHAR,
              hour VARCHAR,
              sunday_start_hour VARCHAR,
              sunday_end_hour VARCHAR,
              monday_start_hour VARCHAR,
              monday_end_hour VARCHAR,
              tuesday_start_hour VARCHAR,
              tuesday_end_hour VARCHAR,
              wednesday_start_hour VARCHAR,
              wednesday_end_hour VARCHAR,
              thursday_start_hour VARCHAR,
              thursday_end_hour VARCHAR,
              friday_start_hour VARCHAR,
              friday_end_hour VARCHAR,
              saturday_start_hour VARCHAR,
              saturday_end_hour VARCHAR,
              content VARCHAR NOT NULL,
              embedding FLOAT8[] NOT NULL
            )
            """,
            "DROP TABLE IF EXISTS flights",
            """
            CREATE TABLE flights(
              id BIGINT PRIMARY KEY,
              airline VARCHAR,
              flight_number VARCHAR,
              departure_airport VARCHAR,
              arrival_airport VARCHAR,
              departure_time VARCHAR(100),
              arrival_time VARCHAR(100),
              departure_gate VARCHAR,
              arrival_gate VARCHAR
            )
            """,
            "DROP TABLE IF EXISTS policies",
            """
            CREATE TABLE policies(
              id BIGINT PRIMARY KEY,
              content VARCHAR NOT NULL,
              embedding FLOAT8[] NOT NULL
            )
            """,
            "DROP TABLE IF EXISTS tickets",
            """
            CREATE TABLE tickets(
              user_id VARCHAR,
              user_name VARCHAR,
              user_email VARCHAR,
              airline VARCHAR,
              flight_number VARCHAR,
              departure_airport VARCHAR,
              arrival_airport VARCHAR,
              departure_time VARCHAR(100),
              arrival_time VARCHAR(100),
              PRIMARY KEY(user_id, airline, flight_number, departure_time)
            )
            """,
        ]

        # Update the schema and block until the operation finishes (or times out).
        operation = self.__database.update_ddl(ddl)
        print("Waiting for schema update operation to complete...")
        operation.result(self.OPERATION_TIMEOUT_SECONDS)
        print("Schema update operation completed")

        self._insert_in_batches(
            "airports",
            self.AIRPORT_COLUMNS,
            [self._row_values(a, self.AIRPORT_COLUMNS) for a in airports],
        )
        # Opening hours are datetime.time values; the columns are VARCHAR.
        self._insert_in_batches(
            "amenities",
            self.AMENITIES_COLUMNS,
            [
                self._row_values(a, self.AMENITIES_COLUMNS, datetime.time)
                for a in amenities
            ],
        )
        # Departure/arrival times are datetime values; the columns are VARCHAR.
        self._insert_in_batches(
            "flights",
            self.FLIGHTS_COLUMNS,
            [
                self._row_values(f, self.FLIGHTS_COLUMNS, datetime.datetime)
                for f in flights
            ],
        )
        self._insert_in_batches(
            "policies",
            self.POLICIES_COLUMNS,
            [self._row_values(p, self.POLICIES_COLUMNS) for p in policies],
        )
        return None

    def _select_all(self, table: str, columns: list[str]) -> list[Any]:
        """
        Return every row of ``table`` (ordered by id) as a fully materialized list.

        Rows are consumed while the snapshot is still open, so errors raised
        while streaming results surface here rather than after the snapshot
        has been released. ``table`` and ``columns`` come from class constants,
        never from user input.
        """
        with self.__database.snapshot() as snapshot:
            results = snapshot.execute_sql(
                "SELECT {} FROM {} ORDER BY id ASC".format(",".join(columns), table)
            )
            return list(results)

    async def export_data(
        self,
    ) -> tuple[
        list[models.Airport],
        list[models.Amenity],
        list[models.Flight],
        list[models.Policy],
    ]:
        """
        Export data from the Spanner database.

        Tables are read in order (airports, amenities, flights, policies). If
        reading a table fails, the error is printed and the lists gathered so
        far are returned, with the remaining ones left empty.

        Returns:
            tuple: A tuple containing lists of airports, amenities, flights, and policies.
        """
        exports = (
            ("airports", self.AIRPORT_COLUMNS, models.Airport),
            ("amenities", self.AMENITIES_COLUMNS, models.Amenity),
            ("flights", self.FLIGHTS_COLUMNS, models.Flight),
            ("policies", self.POLICIES_COLUMNS, models.Policy),
        )
        exported: list[list[Any]] = [[] for _ in exports]
        for index, (table, columns, model) in enumerate(exports):
            try:
                rows = self._select_all(table, columns)
            except Exception as e:
                # Handle any exceptions, such as database connection errors
                print(f"Error occurred while fetch {table}: {e}")
                break
            exported[index] = [
                model.model_validate(dict(zip(columns, row))) for row in rows
            ]
        return exported[0], exported[1], exported[2], exported[3]

    async def get_airport_by_id(
        self, id: int
    ) -> tuple[Optional[models.Airport], Optional[str]]:
        """
        Retrieve an airport by its ID.

        Args:
            id (int): The ID of the airport.

        Returns:
            tuple: The matching Airport (or None if no row matches) and an
            error message (always None).
        """
        with self.__database.snapshot() as snapshot:
            results = snapshot.execute_sql(
                sql="SELECT id, iata, name, city, country FROM airports WHERE id = $1",
                params={"p1": id},
                param_types={"p1": param_types.INT64},
            )
            airports = [
                models.Airport.model_validate(dict(zip(self.AIRPORT_COLUMNS, row)))
                for row in results
            ]
        # No matching row: report "not found" instead of raising IndexError.
        if not airports:
            return None, None
        return airports[0], None

    async def get_airport_by_iata(
        self, iata: str
    ) -> tuple[Optional[models.Airport], Optional[str]]:
        """
        Retrieve an airport by its IATA code (case-insensitive LIKE match).

        Args:
            iata (str): The IATA code of the airport.

        Returns:
            tuple: The first matching Airport (or None if no row matches) and
            an error message (always None).
        """
        with self.__database.snapshot() as snapshot:
            # Execute SQL query to fetch airport by IATA code
            results = snapshot.execute_sql(
                sql="""
                    SELECT id, iata, name, city, country FROM airports
                    WHERE LOWER(iata) LIKE LOWER($1)
                """,
                params={"p1": iata},
                param_types={"p1": param_types.STRING},
            )
            airports = [
                models.Airport.model_validate(dict(zip(self.AIRPORT_COLUMNS, row)))
                for row in results
            ]
        if not airports:
            return None, None
        return airports[0], None

    async def search_airports(
        self,
        country: Optional[str] = None,
        city: Optional[str] = None,
        name: Optional[str] = None,
    ) -> tuple[list[models.Airport], Optional[str]]:
        """
        Search for airports based on optional parameters.

        A parameter left as None does not filter; ``name`` is a substring match.

        Args:
            country (Optional[str]): The country of the airport.
            city (Optional[str]): The city of the airport.
            name (Optional[str]): The name of the airport.

        Returns:
            tuple: A list of matching Airport instances and an error message
            (always None).
        """
        query = """
            SELECT id, iata, name, city, country FROM airports
            WHERE ($1 IS NULL OR LOWER(country) LIKE LOWER($1))
            AND ($2 IS NULL OR LOWER(city) LIKE LOWER($2))
            AND ($3 IS NULL OR LOWER(name) LIKE '%' || LOWER($3) || '%')
        """
        with self.__database.snapshot() as snapshot:
            results = snapshot.execute_sql(
                sql=query,
                params={
                    "p1": country,
                    "p2": city,
                    "p3": name,
                },
                param_types={
                    "p1": param_types.STRING,
                    "p2": param_types.STRING,
                    "p3": param_types.STRING,
                },
            )
            airports = [
                models.Airport.model_validate(dict(zip(self.AIRPORT_COLUMNS, row)))
                for row in results
            ]
        return airports, None

    async def get_amenity(
        self, id: int
    ) -> tuple[Optional[models.Amenity], Optional[str]]:
        """
        Retrieves an amenity by its ID.

        Only the summary columns (id through hour) are fetched.

        Args:
            id (int): The ID of the amenity.

        Returns:
            tuple: The matching Amenity (or None if no row matches) and an
            error message (always None).
        """
        with self.__database.snapshot() as snapshot:
            results = snapshot.execute_sql(
                sql="""
                    SELECT id, name, description, location, terminal, category, hour
                    FROM amenities WHERE id = $1
                """,
                params={"p1": id},
                param_types={"p1": param_types.INT64},
            )
            # zip() stops at the 7 selected columns, pairing them with the
            # first 7 entries of AMENITIES_COLUMNS.
            amenities = [
                models.Amenity.model_validate(dict(zip(self.AMENITIES_COLUMNS, row)))
                for row in results
            ]
        if not amenities:
            return None, None
        return amenities[0], None

    async def amenities_search(
        self, query_embedding: list[float], similarity_threshold: float, top_k: int
    ) -> tuple[list[Any], Optional[str]]:
        """
        Search for amenities based on similarity to a query embedding.

        Args:
            query_embedding (list[float]): The embedding representing the query.
            similarity_threshold (float): The minimum similarity threshold for results.
            top_k (int): The maximum number of results to return.

        Returns:
            tuple: A list of dicts (name, description, location, terminal,
            category, hour) ordered from most to least similar, and an error
            message (always None).
        """
        query = """
            SELECT name, description, location, terminal, category, hour
            FROM (
                SELECT name, description, location, terminal, category, hour,
                spanner.cosine_distance(embedding, $1) AS similarity
                FROM amenities
            ) AS sorted_amenities
            WHERE (1 - similarity) > $2
            ORDER BY similarity
            LIMIT $3
        """
        with self.__database.snapshot() as snapshot:
            results = snapshot.execute_sql(
                sql=query,
                params={
                    "p1": query_embedding,
                    "p2": similarity_threshold,
                    "p3": top_k,
                },
                param_types={
                    "p1": param_types.Array(param_types.FLOAT64),
                    "p2": param_types.FLOAT64,
                    "p3": param_types.INT64,
                },
            )
            # Skip "id": the selected columns start at "name".
            amenities = [
                dict(zip(self.AMENITIES_COLUMNS[1:], row)) for row in results
            ]
        return amenities, None

    async def get_flight(
        self, flight_id: int
    ) -> tuple[Optional[models.Flight], Optional[str]]:
        """
        Retrieves a flight by its ID.

        Args:
            flight_id (int): The ID of the flight.

        Returns:
            tuple: The matching Flight (or None if no row matches) and an
            error message (always None).
        """
        with self.__database.snapshot() as snapshot:
            results = snapshot.execute_sql(
                sql="""
                    SELECT id, airline, flight_number, departure_airport,
                      arrival_airport, departure_time, arrival_time,
                      departure_gate, arrival_gate
                    FROM flights
                    WHERE id = $1
                """,
                params={"p1": flight_id},
                param_types={"p1": param_types.INT64},
            )
            flights = [
                models.Flight.model_validate(dict(zip(self.FLIGHTS_COLUMNS, row)))
                for row in results
            ]
        if not flights:
            return None, None
        return flights[0], None

    async def search_flights_by_number(
        self,
        airline: str,
        number: str,
    ) -> tuple[list[models.Flight], Optional[str]]:
        """
        Search for flights by airline and flight number (at most 10 results).

        Args:
            airline (str): The airline of the flight.
            number (str): The flight number.

        Returns:
            tuple: A list of matching Flight instances and an error message
            (always None).
        """
        with self.__database.snapshot() as snapshot:
            results = snapshot.execute_sql(
                sql="""
                    SELECT id, airline, flight_number, departure_airport,
                      arrival_airport, departure_time, arrival_time,
                      departure_gate, arrival_gate
                    FROM flights
                    WHERE airline = $1
                    AND flight_number = $2
                    LIMIT 10
                """,
                params={"p1": airline, "p2": number},
                param_types={
                    "p1": param_types.STRING,
                    "p2": param_types.STRING,
                },
            )
            flights = [
                models.Flight.model_validate(dict(zip(self.FLIGHTS_COLUMNS, row)))
                for row in results
            ]
        return flights, None

    async def search_flights_by_airports(
        self,
        date: str,
        departure_airport: Optional[str] = None,
        arrival_airport: Optional[str] = None,
    ) -> tuple[list[models.Flight], Optional[str]]:
        """
        Search for flights by departure and/or arrival airports on a given day
        (at most 10 results).

        Args:
            date (str): The date of the flights in 'YYYY-MM-DD' format.
            departure_airport (str, optional): The departure airport code. Defaults to None.
            arrival_airport (str, optional): The arrival airport code. Defaults to None.

        Returns:
            tuple: A list of matching Flight instances and an error message
            (always None).
        """
        # departure_time is stored as VARCHAR, so it is cast to timestamptz
        # for the one-day window comparison.
        query = """
            SELECT id, airline, flight_number, departure_airport,
              arrival_airport, departure_time, arrival_time,
              departure_gate, arrival_gate
            FROM flights
            WHERE (COALESCE($1) IS NULL OR LOWER(departure_airport) LIKE LOWER($1))
            AND (COALESCE($2) IS NULL OR LOWER(arrival_airport) LIKE LOWER($2))
            AND CAST(departure_time as timestamptz) >= CAST($3 AS timestamptz)
            AND cast(departure_time as timestamptz) < spanner.timestamptz_add(CAST($3 AS timestamptz), '1 day')
            LIMIT 10
        """
        with self.__database.snapshot() as snapshot:
            results = snapshot.execute_sql(
                sql=query,
                params={
                    "p1": departure_airport,
                    "p2": arrival_airport,
                    "p3": date,
                },
                param_types={
                    "p1": param_types.STRING,
                    "p2": param_types.STRING,
                    "p3": param_types.STRING,
                },
            )
            flights = [
                models.Flight.model_validate(dict(zip(self.FLIGHTS_COLUMNS, row)))
                for row in results
            ]
        return flights, None

    async def validate_ticket(
        self,
        airline: str,
        flight_number: str,
        departure_airport: str,
        departure_time: str,
    ) -> tuple[Optional[models.Flight], Optional[str]]:
        """
        Look up the flight that a prospective ticket refers to.

        Airline, flight number and departure airport are matched
        case-insensitively; departure_time must equal the stored string exactly.

        Args:
            airline (str): The airline of the flight.
            flight_number (str): The flight number.
            departure_airport (str): The departure airport code.
            departure_time (str): The departure time, as stored in flights.

        Returns:
            tuple: The first matching Flight (or None if none matches) and an
            error message (always None).
        """
        with self.__database.snapshot() as snapshot:
            results = snapshot.execute_sql(
                sql="""
                    SELECT id, airline, flight_number, departure_airport,
                      arrival_airport, departure_time, arrival_time,
                      departure_gate, arrival_gate
                    FROM flights
                    WHERE LOWER(airline) LIKE LOWER($1)
                    AND LOWER(flight_number) LIKE LOWER($2)
                    AND LOWER(departure_airport) LIKE LOWER($3)
                    AND departure_time = $4
                """,
                params={
                    "p1": airline,
                    "p2": flight_number,
                    "p3": departure_airport,
                    "p4": departure_time,
                },
                param_types={
                    "p1": param_types.STRING,
                    "p2": param_types.STRING,
                    "p3": param_types.STRING,
                    "p4": param_types.STRING,
                },
            )
            flights = [
                models.Flight.model_validate(dict(zip(self.FLIGHTS_COLUMNS, row)))
                for row in results
            ]
        if not flights:
            return None, None
        return flights[0], None

    async def insert_ticket(
        self,
        user_id: str,
        user_name: str,
        user_email: str,
        airline: str,
        flight_number: str,
        departure_airport: str,
        arrival_airport: str,
        departure_time: str,
        arrival_time: str,
    ):
        """
        Inserts a ticket into the database.

        Args:
            user_id (str): The ID of the user.
            user_name (str): The name of the user.
            user_email (str): The email of the user.
            airline (str): The airline of the flight.
            flight_number (str): The flight number.
            departure_airport (str): The departure airport code.
            arrival_airport (str): The arrival airport code.
            departure_time (str): The departure time, "%Y-%m-%d %H:%M:%S".
            arrival_time (str): The arrival time, "%Y-%m-%d %H:%M:%S".

        Raises:
            ValueError: If either time string does not match "%Y-%m-%d %H:%M:%S".
        """
        departure_time_datetime = datetime.datetime.strptime(
            departure_time, "%Y-%m-%d %H:%M:%S"
        )
        arrival_time_datetime = datetime.datetime.strptime(
            arrival_time, "%Y-%m-%d %H:%M:%S"
        )
        # NOTE(review): the tickets time columns are VARCHAR(100) but datetime
        # objects are written here; the Spanner client serializes them to a
        # string that presumably differs from the str(datetime) format used for
        # flights in initialize_data. Confirm before comparing the two tables'
        # time strings directly.
        with self.__database.batch() as batch:
            batch.insert(
                table="tickets",
                columns=self.TICKETS_COLUMNS,
                values=[
                    [
                        user_id,
                        user_name,
                        user_email,
                        airline,
                        flight_number,
                        departure_airport,
                        arrival_airport,
                        departure_time_datetime,
                        arrival_time_datetime,
                    ]
                ],
            )

    async def list_tickets(
        self,
        user_id: str,
    ) -> tuple[list[Any], Optional[str]]:
        """
        Retrieves a list of tickets for a user.

        Args:
            user_id (str): The ID of the user.

        Returns:
            tuple: A list of Ticket instances and an error message (always None).
        """
        with self.__database.snapshot() as snapshot:
            # Select exactly TICKETS_COLUMNS, in that order, so the zip below
            # pairs each value with its own field name.
            results = snapshot.execute_sql(
                sql="""
                    SELECT user_id, user_name, user_email, airline, flight_number,
                      departure_airport, arrival_airport, departure_time, arrival_time
                    FROM tickets
                    WHERE user_id = $1
                """,
                params={"p1": user_id},
                param_types={"p1": param_types.STRING},
            )
            tickets = [
                models.Ticket.model_validate(dict(zip(self.TICKETS_COLUMNS, row)))
                for row in results
            ]
        return tickets, None

    async def policies_search(
        self, query_embedding: list[float], similarity_threshold: float, top_k: int
    ) -> tuple[list[str], Optional[str]]:
        """
        Search for policies based on similarity to a query embedding.

        Args:
            query_embedding (list[float]): The embedding representing the query.
            similarity_threshold (float): The minimum similarity threshold for results.
            top_k (int): The maximum number of results to return.

        Returns:
            tuple: A list of policy content strings ordered from most to least
            similar, and an error message (always None).
        """
        query = """
            SELECT content
            FROM (
                SELECT content, spanner.cosine_distance(embedding, $1) AS similarity
                FROM policies
            ) AS sorted_policies
            WHERE (1 - similarity) > $2
            ORDER BY similarity
            LIMIT $3
        """
        with self.__database.snapshot() as snapshot:
            results = snapshot.execute_sql(
                sql=query,
                params={
                    "p1": query_embedding,
                    "p2": similarity_threshold,
                    "p3": top_k,
                },
                param_types={
                    "p1": param_types.Array(param_types.FLOAT64),
                    "p2": param_types.FLOAT64,
                    "p3": param_types.INT64,
                },
            )
            policies = [row[0] for row in results]
        return policies, None

    async def close(self):
        """
        Closes the database client connection.
        """
        self.__client.close()