dbconnectors/PgConnector.py
"""
PostgreSQL Connector Class
"""
import asyncpg
from google.cloud.sql.connector import Connector
from sqlalchemy import create_engine
import pandas as pd
from sqlalchemy.sql import text
from pgvector.asyncpg import register_vector
import asyncio
from pg8000.exceptions import DatabaseError
from utilities import root_dir
from dbconnectors import DBConnector
from abc import ABC
def pg_specific_data_types():
return '''
PostgreSQL offers a wide variety of datatypes to store different types of data effectively. Here's a breakdown of the available categories:
Numeric datatypes -
SMALLINT: Stores small-range integers between -32768 and 32767.
INTEGER: Stores typical integers between -2147483648 and 2147483647.
BIGINT: Stores large-range integers between -9223372036854775808 and 9223372036854775807.
DECIMAL(p,s): Stores arbitrary precision numbers with a maximum of p digits and s digits to the right of the decimal point.
NUMERIC: Equivalent to DECIMAL; PostgreSQL treats the two as the same type.
REAL: Stores single-precision floating-point numbers with an approximate range of -3.4E+38 to 3.4E+38.
DOUBLE PRECISION: Stores double-precision floating-point numbers with an approximate range of -1.7E+308 to 1.7E+308.
Character datatypes -
CHAR(n): Fixed-length character string with a specified length of n characters.
VARCHAR(n): Variable-length character string with a maximum length of n characters.
TEXT: Variable-length string with no maximum size limit.
CHARACTER VARYING(n): Alias for VARCHAR(n).
CHARACTER: Alias for CHAR.
Monetary datatypes -
MONEY: Stores monetary amounts with two decimal places.
Date/Time datatypes -
DATE: Stores dates without time information.
TIME: Stores time of day without date information (optionally with time zone).
TIMESTAMP: Stores both date and time information (optionally with time zone).
INTERVAL: Stores time intervals between two points in time.
Binary types -
BYTEA: Stores variable-length binary data.
BIT(n): Stores fixed-length bit strings of n bits.
BIT VARYING: Stores variable-length bit strings.
Other types -
BOOLEAN: Stores true or false values.
UUID: Stores universally unique identifiers.
XML: Stores XML data.
JSON: Stores JSON data.
ENUM: Stores user-defined enumerated values.
RANGE: Stores ranges of data values.
This list covers the most common datatypes in PostgreSQL.
'''
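# Illustrative only: a hypothetical DDL snippet exercising several of the
# datatypes listed above (the table and column names are made up):
#
#   CREATE TABLE orders (
#       id          BIGINT PRIMARY KEY,
#       sku         VARCHAR(64) NOT NULL,
#       price       NUMERIC(10, 2),
#       paid        BOOLEAN DEFAULT FALSE,
#       metadata    JSON,
#       created_at  TIMESTAMP WITH TIME ZONE
#   );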
class PgConnector(DBConnector, ABC):
"""
A connector class for interacting with PostgreSQL databases.
    This class provides methods for establishing connections to PostgreSQL instances, executing SQL queries, retrieving results as DataFrames, caching known SQL queries, and managing embeddings. It uses the Cloud SQL Python Connector, with `pg8000` for synchronous connections and `asyncpg` for asynchronous operations.
Attributes:
project_id (str): The Google Cloud project ID where the PostgreSQL instance resides.
region (str): The region where the PostgreSQL instance is located.
instance_name (str): The name of the PostgreSQL instance.
database_name (str): The name of the database to connect to.
database_user (str): The username for authentication.
database_password (str): The password for authentication.
pool (Engine): A SQLAlchemy engine object for managing database connections.
Methods:
getconn() -> connection:
Establishes a connection to the PostgreSQL instance and returns a connection object.
retrieve_df(query) -> pd.DataFrame:
Executes a SQL query and returns the results as a pandas DataFrame. Handles potential database errors.
cache_known_sql() -> None:
Caches known good SQL queries into a PostgreSQL table for future reference.
retrieve_matches(mode, user_grouping, qe, similarity_threshold, limit) -> list:
Retrieves similar matches (table schemas, column schemas, or example queries) from the database based on the given mode, query embedding (`qe`), similarity threshold, and limit.
getSimilarMatches(mode, user_grouping, qe, num_matches, similarity_threshold) -> str:
Gets similar matches for tables, columns, or examples asynchronously, formatting the results into a string.
test_sql_plan_execution(generated_sql) -> Tuple[bool, pd.DataFrame]:
Tests the execution plan of a generated SQL query in PostgreSQL. Returns a tuple indicating success and the result DataFrame.
getExactMatches(query) -> str or None:
Checks if the exact question is present in the example SQL set and returns the corresponding SQL query if found.
return_column_schema_sql(schema) -> str:
Returns a SQL query to retrieve column schema information from a PostgreSQL schema.
return_table_schema_sql(schema) -> str:
Returns a SQL query to retrieve table schema information from a PostgreSQL schema.
"""
def __init__(self,
project_id:str,
region:str,
instance_name:str,
database_name:str,
database_user:str,
database_password:str):
self.project_id = project_id
self.region = region
self.instance_name = instance_name
self.database_name = database_name
self.database_user = database_user
self.database_password = database_password
self.pool = create_engine(
"postgresql+pg8000://",
creator=self.getconn,
)
def getconn(self):
"""
function to return the database connection object
"""
# initialize Connector object
connector = Connector()
conn = connector.connect(
f"{self.project_id}:{self.region}:{self.instance_name}",
"pg8000",
user=f"{self.database_user}",
password=f"{self.database_password}",
db=f"{self.database_name}"
)
return conn
    def retrieve_df(self, query):
        """
        Executes a SQL query and returns the result as a pandas DataFrame.
        On a database error, returns a single-row DataFrame whose
        'Error. Message' column contains the exception text.
        """
        try:
            with self.pool.connect() as db_conn:
                result_df = pd.read_sql(text(query), con=db_conn)
            return result_df
        except Exception as e:
            print(f"Database Error: {e}")
            return pd.DataFrame({'Error. Message': str(e)}, index=[0])
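    # A minimal usage sketch for retrieve_df (assumes a reachable instance and
    # valid credentials; see the fuller example at the end of this module):
    #
    #   df = connector.retrieve_df("SELECT 1 AS one")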
    async def cache_known_sql(self):
        """
        Caches known-good SQL queries from scripts/known_good_sql.csv into the
        'query_example_embeddings' table for future reference.
        """
        df = pd.read_csv(f"{root_dir}/scripts/known_good_sql.csv")
        df = df.loc[:, ["prompt", "sql", "database_name"]]
        df = df.dropna()
        # Rename so the column list matches the target table for the COPY below.
        df = df.rename(columns={"database_name": "user_grouping"})
loop = asyncio.get_running_loop()
async with Connector(loop=loop) as connector:
            # Create a connection to the Cloud SQL database.
conn: asyncpg.Connection = await connector.connect_async(
f"{self.project_id}:{self.region}:{self.instance_name}",
"asyncpg",
user=f"{self.database_user}",
password=f"{self.database_password}",
db=f"{self.database_name}",
)
await register_vector(conn)
# Delete the table if it exists.
await conn.execute("DROP TABLE IF EXISTS query_example_embeddings CASCADE")
# Create the `query_example_embeddings` table.
await conn.execute(
"""CREATE TABLE query_example_embeddings(
prompt TEXT,
sql TEXT,
user_grouping TEXT)"""
)
# Copy the dataframe to the 'query_example_embeddings' table.
tuples = list(df.itertuples(index=False))
await conn.copy_records_to_table(
"query_example_embeddings", records=tuples, columns=list(df), timeout=10000
)
await conn.close()
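    # The CSV read above is expected to look roughly like this (illustrative
    # rows; the real file lives at scripts/known_good_sql.csv):
    #
    #   prompt,sql,database_name
    #   "How many products are there?","SELECT COUNT(*) FROM products;","retail"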
    async def retrieve_matches(self, mode, user_grouping, qe, similarity_threshold, limit):
        """
        Retrieves the most similar matches (table schemas, column schemas, or
        example queries) for the given query embedding.
        Mode must be either 'table', 'column', or 'example'.
        """
matches = []
loop = asyncio.get_running_loop()
async with Connector(loop=loop) as connector:
            # Create a connection to the Cloud SQL database.
conn: asyncpg.Connection = await connector.connect_async(
f"{self.project_id}:{self.region}:{self.instance_name}",
"asyncpg",
user=f"{self.database_user}",
password=f"{self.database_password}",
db=f"{self.database_name}",
)
await register_vector(conn)
# Prepare the SQL depending on 'mode'
if mode == 'table':
sql = """
SELECT content as tables_content,
1 - (embedding <=> $1) AS similarity
FROM table_details_embeddings
WHERE 1 - (embedding <=> $1) > $2
AND user_grouping = $4
ORDER BY similarity DESC LIMIT $3
"""
elif mode == 'column':
sql = """
SELECT content as columns_content,
1 - (embedding <=> $1) AS similarity
FROM tablecolumn_details_embeddings
WHERE 1 - (embedding <=> $1) > $2
AND user_grouping = $4
ORDER BY similarity DESC LIMIT $3
"""
elif mode == 'example':
sql = """
SELECT user_grouping, example_user_question, example_generated_sql,
1 - (embedding <=> $1) AS similarity
FROM example_prompt_sql_embeddings
WHERE 1 - (embedding <=> $1) > $2
AND user_grouping = $4
ORDER BY similarity DESC LIMIT $3
"""
        else:
            raise ValueError("No valid mode. Must be either 'table', 'column', or 'example'")
        name_txt = ''
# FETCH RESULTS FROM POSTGRES DB
results = await conn.fetch(
sql,
qe,
similarity_threshold,
limit,
            user_grouping
)
# CHECK RESULTS
if len(results) == 0:
print(f"Did not find any results for {mode}. Adjust the query parameters.")
else:
print(f"Found {len(results)} similarity matches for {mode}.")
        if mode == 'table':
            name_txt = ''
            for r in results:
                name_txt += r["tables_content"] + "\n\n"
        elif mode == 'column':
            name_txt = ''
            for r in results:
                name_txt += r["columns_content"] + "\n\n"
        elif mode == 'example':
            name_txt = ''
            for r in results:
                example_user_question = r["example_user_question"]
                example_sql = r["example_generated_sql"]
                name_txt = name_txt + "\n Example_question: " + example_user_question + "; Example_SQL: " + example_sql
        # Invalid modes were already rejected before the fetch above.
matches.append(name_txt)
# Close the connection to the database.
await conn.close()
return matches
    async def getSimilarMatches(self, mode, user_grouping, qe, num_matches, similarity_threshold):
        """
        Gets similar matches for tables, columns, or examples asynchronously and
        returns a formatted result string (None when no example match is found).
        """
        if mode in ('table', 'column'):
            match_result = await self.retrieve_matches(mode, user_grouping, qe, similarity_threshold, num_matches)
            match_result = match_result[0]
        elif mode == 'example':
            match_result = await self.retrieve_matches(mode, user_grouping, qe, similarity_threshold, num_matches)
if len(match_result) == 0:
match_result = None
else:
match_result = match_result[0]
return match_result
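    # A minimal sketch of a call, inside an async function (assumes 'qe' is a
    # pgvector-compatible query embedding and 'retail' is an existing grouping):
    #
    #   tables = await connector.getSimilarMatches('table', 'retail', qe, 5, 0.3)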
    def test_sql_plan_execution(self, generated_sql):
        """
        Tests a generated SQL statement by running EXPLAIN ANALYZE on it.
        Returns a tuple of (success flag, result DataFrame or error string).
        Note that EXPLAIN ANALYZE actually executes the statement.
        """
        try:
            exec_result_df = pd.DataFrame()
            sql = f"EXPLAIN ANALYZE {generated_sql}"
exec_result_df = self.retrieve_df(sql)
if not exec_result_df.empty:
if str(exec_result_df.iloc[0]).startswith('Error. Message'):
correct_sql = False
else:
print('\n No need to rewrite the query. This seems to work fine and returned rows...')
correct_sql = True
else:
print('\n No need to rewrite the query. This seems to work fine but no rows returned...')
correct_sql = True
return correct_sql, exec_result_df
        except Exception as e:
            return False, str(e)
def getExactMatches(self, query):
"""
Checks if the exact question is already present in the example SQL set
"""
        # Note: the user question is interpolated directly into the SQL text,
        # so quotes inside the question can break this lookup.
        check_history_sql = f"""SELECT example_user_question, example_generated_sql
        FROM example_prompt_sql_embeddings
        WHERE lower(example_user_question) = lower('{query}') LIMIT 1;"""
exact_sql_history = self.retrieve_df(check_history_sql)
if exact_sql_history[exact_sql_history.columns[0]].count() != 0:
sql_example_txt = ''
exact_sql = ''
for index, row in exact_sql_history.iterrows():
example_user_question=row["example_user_question"]
example_sql=row["example_generated_sql"]
exact_sql=example_sql
sql_example_txt = sql_example_txt + "\n Example_question: "+example_user_question+ "; Example_SQL: "+example_sql
# print("Found a matching question from the history!" + str(sql_example_txt))
final_sql=exact_sql
else:
print("No exact match found for the user prompt")
final_sql = None
return final_sql
def return_column_schema_sql(self, schema, table_names=None):
"""
This SQL returns a df containing the cols table_schema, table_name, column_name, data_type, column_description, table_description, primary_key, column_constraints
for the schema specified above, e.g. 'retail'
- table_schema: e.g. retail
- table_name: name of the table inside the schema, e.g. products
- column_name: name of each col in each table in the schema, e.g. id_product
- data_type: data type of each col
- column_description: col descriptor, can be empty
- table_description: text descriptor, can be empty
- primary_key: whether the col is PK; if yes, the field contains the col_name
- column_constraints: e.g. "Primary key for this table"
"""
table_filter_clause = ""
        if table_names:
            formatted_table_names = [f"'{name}'" for name in table_names]
            table_filter_clause = f"and table_name in ({', '.join(formatted_table_names)})"
column_schema_sql = f'''
WITH
columns_schema
AS
(select c.table_schema,c.table_name,c.column_name,c.data_type,d.description as column_description, obj_description(c1.oid) as table_description
from information_schema.columns c
inner join pg_class c1
on c.table_name=c1.relname
inner join pg_catalog.pg_namespace n
on c.table_schema=n.nspname
and c1.relnamespace=n.oid
left join pg_catalog.pg_description d
on d.objsubid=c.ordinal_position
and d.objoid=c1.oid
where
c.table_schema='{schema}' {table_filter_clause}) ,
pk_schema as
(SELECT table_name, column_name AS primary_key
FROM information_schema.key_column_usage
WHERE TABLE_SCHEMA='{schema}' {table_filter_clause}
AND CONSTRAINT_NAME like '%_pkey%'
ORDER BY table_name, primary_key),
fk_schema as
(SELECT table_name, column_name AS foreign_key
FROM information_schema.key_column_usage
WHERE TABLE_SCHEMA='{schema}' {table_filter_clause}
AND CONSTRAINT_NAME like '%_fkey%'
ORDER BY table_name, foreign_key)
select lr.*,
case
when primary_key is not null then 'Primary key for this table'
        when foreign_key is not null then CONCAT('Foreign key ', column_description)
else null
END as column_constraints
from
(select l.*,r.primary_key
from
columns_schema l
left outer join
pk_schema r
on
l.table_name=r.table_name
and
l.column_name=r.primary_key) lr
left outer join
fk_schema rt
on
lr.table_name=rt.table_name
and
lr.column_name=rt.foreign_key
;
'''
return column_schema_sql
def return_table_schema_sql(self, schema, table_names=None):
"""
This SQL returns a df containing the cols table_schema, table_name, table_description, table_columns (with cols in the table)
for the schema specified above, e.g. 'retail'
- table_schema: e.g. retail
- table_name: name of the table inside the schema, e.g. products
- table_description: text descriptor, can be empty
- table_columns: aggregate of the col names inside the table
"""
table_filter_clause = ""
        if table_names:
            formatted_table_names = [f"'{name}'" for name in table_names]
            table_filter_clause = f"and table_name in ({', '.join(formatted_table_names)})"
table_schema_sql = f'''
SELECT table_schema, table_name,table_description, array_to_string(array_agg(column_name), ' , ') as table_columns
FROM
(select c.table_schema,c.table_name,c.column_name,c.ordinal_position,c.column_default,c.data_type,d.description, obj_description(c1.oid) as table_description
from information_schema.columns c
inner join pg_class c1
on c.table_name=c1.relname
inner join pg_catalog.pg_namespace n
on c.table_schema=n.nspname
and c1.relnamespace=n.oid
left join pg_catalog.pg_description d
on d.objsubid=c.ordinal_position
and d.objoid=c1.oid
where
c.table_schema='{schema}' {table_filter_clause} ) data
GROUP BY table_schema, table_name, table_description
ORDER BY table_name;
'''
return table_schema_sql
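    # A minimal sketch combining the two schema helpers with retrieve_df
    # (assumes a 'retail' schema exists in the connected database):
    #
    #   tables_df = connector.retrieve_df(connector.return_table_schema_sql('retail'))
    #   columns_df = connector.retrieve_df(connector.return_column_schema_sql('retail'))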
    def get_column_samples(self, columns_df):
        """
        Adds a 'sample_values' column to columns_df by reading each column's
        most common values from pg_stats.
        """
        sample_column_list = []
        for _, row in columns_df.iterrows():
            get_column_sample_sql = f'''SELECT most_common_vals AS sample_values FROM pg_stats WHERE tablename = '{row["table_name"]}' AND schemaname = '{row["table_schema"]}' AND attname = '{row["column_name"]}' '''
            column_samples_df = self.retrieve_df(get_column_sample_sql)
            sample_column_list.append(column_samples_df['sample_values'].to_string(index=False).replace("{", "").replace("}", ""))
        columns_df["sample_values"] = sample_column_list
        return columns_df
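# A minimal end-to-end sketch (all identifiers below are placeholders; a real
# deployment needs a reachable Cloud SQL instance with the embeddings tables):
#
#   import asyncio
#
#   connector = PgConnector(
#       project_id="my-project",
#       region="us-central1",
#       instance_name="my-instance",
#       database_name="mydb",
#       database_user="postgres",
#       database_password="secret",
#   )
#   df = connector.retrieve_df("SELECT now() AS server_time")
#   ok, plan_df = connector.test_sql_plan_execution("SELECT 1")
#   asyncio.run(connector.cache_known_sql())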