# embeddings/store_embeddings.py
import asyncio
import asyncpg
import pandas as pd
import numpy as np
from pgvector.asyncpg import register_vector
from google.cloud.sql.connector import Connector
from langchain_community.embeddings import VertexAIEmbeddings
from google.cloud import bigquery
from dbconnectors import pgconnector
from agents import EmbedderAgent
from sqlalchemy.sql import text
from utilities import VECTOR_STORE, PROJECT_ID, PG_INSTANCE, PG_DATABASE, PG_USER, PG_PASSWORD, PG_REGION, BQ_OPENDATAQNA_DATASET_NAME, BQ_REGION, EMBEDDING_MODEL
# Module-level embedder shared by add_sql_embedding() to vectorise user questions.
embedder = EmbedderAgent(EMBEDDING_MODEL)
async def store_schema_embeddings(table_details_embeddings,
                                  tablecolumn_details_embeddings,
                                  project_id,
                                  instance_name,
                                  database_name,
                                  schema,
                                  database_user,
                                  database_password,
                                  region,
                                  VECTOR_STORE):
    """
    Store the vectorised table and column details in the vector store.

    Upserts every row of the two embedding dataframes into either a
    Cloud SQL (pgvector) database or a BigQuery dataset, keyed on
    (user_grouping, table_name) for tables and
    (user_grouping, table_name, column_name) for columns.  Also ensures
    the ``example_prompt_sql_embeddings`` table exists.
    This code may run for a few minutes for large schemas.

    Args:
        table_details_embeddings: DataFrame with columns source_type,
            user_grouping, table_schema, table_name, content, embedding.
        tablecolumn_details_embeddings: DataFrame with the same columns
            plus column_name.
        project_id: GCP project id.
        instance_name, database_name, database_user, database_password,
        region: Cloud SQL connection details (ignored for BigQuery).
        schema: BigQuery dataset name (ignored for pgvector).
        VECTOR_STORE: "cloudsql-pgvector" or "bigquery-vector".

    Returns:
        A success message string.

    Raises:
        ValueError: if VECTOR_STORE is not one of the supported stores.
    """
    if VECTOR_STORE == "cloudsql-pgvector":
        loop = asyncio.get_running_loop()
        async with Connector(loop=loop) as connector:
            # Create connection to the Cloud SQL database.
            conn: asyncpg.Connection = await connector.connect_async(
                f"{project_id}:{region}:{instance_name}",  # Cloud SQL instance connection name
                "asyncpg",
                user=database_user,
                password=database_password,
                db=database_name,
            )
            try:
                # Idempotent: enable pgvector and register its asyncpg codec once
                # for this connection (the original re-ran this mid-function).
                await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
                await register_vector(conn)

                # Create the `table_details_embeddings` table to store vector embeddings.
                await conn.execute(
                    """CREATE TABLE IF NOT EXISTS table_details_embeddings(
                        source_type VARCHAR(100) NOT NULL,
                        user_grouping VARCHAR(100) NOT NULL,
                        table_schema VARCHAR(1024) NOT NULL,
                        table_name VARCHAR(1024) NOT NULL,
                        content TEXT,
                        embedding vector(768))"""
                )
                # Upsert each table-level embedding, keyed on (user_grouping, table_name).
                for _, row in table_details_embeddings.iterrows():
                    await conn.execute(
                        """
                        MERGE INTO table_details_embeddings AS target
                        USING (SELECT $1::text AS source_type, $2::text AS user_grouping, $3::text AS table_schema, $4::text AS table_name, $5::text AS content, $6::vector AS embedding) AS source
                        ON target.user_grouping = source.user_grouping AND target.table_name = source.table_name
                        WHEN MATCHED THEN
                            UPDATE SET source_type = source.source_type, table_schema = source.table_schema, content = source.content, embedding = source.embedding
                        WHEN NOT MATCHED THEN
                            INSERT (source_type, user_grouping, table_schema, table_name, content, embedding)
                            VALUES (source.source_type, source.user_grouping, source.table_schema, source.table_name, source.content, source.embedding);
                        """,
                        row["source_type"],
                        row["user_grouping"],
                        row["table_schema"],
                        row["table_name"],
                        row["content"],
                        np.array(row["embedding"]),
                    )

                # Create the `tablecolumn_details_embeddings` table to store vector embeddings.
                await conn.execute(
                    """CREATE TABLE IF NOT EXISTS tablecolumn_details_embeddings(
                        source_type VARCHAR(100) NOT NULL,
                        user_grouping VARCHAR(100) NOT NULL,
                        table_schema VARCHAR(1024) NOT NULL,
                        table_name VARCHAR(1024) NOT NULL,
                        column_name VARCHAR(1024) NOT NULL,
                        content TEXT,
                        embedding vector(768))"""
                )
                # Upsert each column-level embedding, keyed on
                # (user_grouping, table_name, column_name).
                for _, row in tablecolumn_details_embeddings.iterrows():
                    await conn.execute(
                        """
                        MERGE INTO tablecolumn_details_embeddings AS target
                        USING (SELECT $1::text AS source_type, $2::text AS user_grouping, $3::text AS table_schema,
                        $4::text AS table_name, $5::text AS column_name, $6::text AS content, $7::vector AS embedding) AS source
                        ON target.user_grouping = source.user_grouping
                        AND target.table_name = source.table_name
                        AND target.column_name = source.column_name
                        WHEN MATCHED THEN
                            UPDATE SET source_type = source.source_type, table_schema = source.table_schema, content = source.content, embedding = source.embedding
                        WHEN NOT MATCHED THEN
                            INSERT (source_type, user_grouping, table_schema, table_name, column_name, content, embedding)
                            VALUES (source.source_type, source.user_grouping, source.table_schema, source.table_name, source.column_name, source.content, source.embedding);
                        """,
                        row["source_type"],
                        row["user_grouping"],
                        row["table_schema"],
                        row["table_name"],
                        row["column_name"],
                        row["content"],
                        np.array(row["embedding"]),
                    )

                # Ensure the known-good-SQL example table exists as well.
                await conn.execute(
                    """CREATE TABLE IF NOT EXISTS example_prompt_sql_embeddings(
                        user_grouping VARCHAR(1024) NOT NULL,
                        example_user_question text NOT NULL,
                        example_generated_sql text NOT NULL,
                        embedding vector(768))"""
                )
            finally:
                await conn.close()

    elif VECTOR_STORE == "bigquery-vector":
        client = bigquery.Client(project=project_id)

        def _escape(value):
            # Escape backslashes and single quotes so values are safe inside a
            # single-quoted BigQuery string literal (the original interpolated
            # raw values, so a quote in a table name broke the DELETE).
            return str(value).replace("\\", "\\\\").replace("'", "\\'")

        # Store table embeddings.
        client.query_and_wait(f'''CREATE TABLE IF NOT EXISTS `{project_id}.{schema}.table_details_embeddings` (
            source_type string NOT NULL, user_grouping string NOT NULL, table_schema string NOT NULL, table_name string NOT NULL, content string, embedding ARRAY<FLOAT64>)''')
        if not table_details_embeddings.empty:
            # Delete the rows being replaced before loading.  Guarded on a
            # non-empty dataframe: an empty one would produce an invalid
            # "WHERE " clause (and there would be nothing to load anyway).
            delete_conditions = table_details_embeddings[['user_grouping', 'table_name']].apply(tuple, axis=1).tolist()
            where_clause = " OR ".join(
                f"(user_grouping = '{_escape(grouping)}' AND table_name = '{_escape(table)}')"
                for grouping, table in delete_conditions
            )
            client.query_and_wait(f"""
                DELETE FROM `{project_id}.{schema}.table_details_embeddings`
                WHERE {where_clause}
            """)
            # .result() blocks until the load job finishes (the original
            # returned without waiting, so data could still be in flight).
            client.load_table_from_dataframe(
                table_details_embeddings,
                f'{project_id}.{schema}.table_details_embeddings',
            ).result()

        # Store column embeddings.
        client.query_and_wait(f'''CREATE TABLE IF NOT EXISTS `{project_id}.{schema}.tablecolumn_details_embeddings` (
            source_type string NOT NULL,user_grouping string NOT NULL, table_schema string NOT NULL, table_name string NOT NULL, column_name string NOT NULL,
            content string, embedding ARRAY<FLOAT64>)''')
        if not tablecolumn_details_embeddings.empty:
            delete_conditions = tablecolumn_details_embeddings[['user_grouping', 'table_name', 'column_name']].apply(tuple, axis=1).tolist()
            where_clause = " OR ".join(
                f"(user_grouping = '{_escape(grouping)}' AND table_name = '{_escape(table)}' AND column_name = '{_escape(column)}')"
                for grouping, table, column in delete_conditions
            )
            client.query_and_wait(f"""
                DELETE FROM `{project_id}.{schema}.tablecolumn_details_embeddings`
                WHERE {where_clause}
            """)
            client.load_table_from_dataframe(
                tablecolumn_details_embeddings,
                f'{project_id}.{schema}.tablecolumn_details_embeddings',
            ).result()

        # Ensure the known-good-SQL example table exists as well.
        client.query_and_wait(f'''CREATE TABLE IF NOT EXISTS `{project_id}.{schema}.example_prompt_sql_embeddings` (
            user_grouping string NOT NULL, example_user_question string NOT NULL, example_generated_sql string NOT NULL,
            embedding ARRAY<FLOAT64>)''')
    else:
        raise ValueError("Please provide a valid Vector Store.")
    return "Embeddings are stored successfully"
async def add_sql_embedding(user_question, generated_sql, database):
    """
    Embed a known-good question/SQL pair and upsert it into the vector store.

    Embeds *user_question*, then delete-and-inserts the
    (user_grouping, question, SQL, embedding) row into the
    ``example_prompt_sql_embeddings`` table of the configured vector
    store (Cloud SQL pgvector or BigQuery).

    Args:
        user_question: Natural-language question to embed.
        generated_sql: Known-good SQL for the question (newlines are
            flattened to spaces before storage).
        database: The user_grouping key the example belongs to.

    Returns:
        1 on completion.
    """
    emb = embedder.create(user_question)
    if VECTOR_STORE == "cloudsql-pgvector":
        loop = asyncio.get_running_loop()
        async with Connector(loop=loop) as connector:
            # Create connection to the Cloud SQL database.
            conn: asyncpg.Connection = await connector.connect_async(
                f"{PROJECT_ID}:{PG_REGION}:{PG_INSTANCE}",  # Cloud SQL instance connection name
                "asyncpg",
                user=PG_USER,
                password=PG_PASSWORD,
                db=PG_DATABASE,
            )
            try:
                await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
                await register_vector(conn)
                # Delete-then-insert implements the upsert, keyed on
                # (user_grouping, example_user_question).
                await conn.execute(
                    "DELETE FROM example_prompt_sql_embeddings WHERE user_grouping= $1 and example_user_question=$2",
                    database,
                    user_question,
                )
                cleaned_sql = generated_sql.replace("\r", " ").replace("\n", " ")
                await conn.execute(
                    "INSERT INTO example_prompt_sql_embeddings (user_grouping, example_user_question, example_generated_sql, embedding) VALUES ($1, $2, $3, $4)",
                    database,
                    user_question,
                    cleaned_sql,
                    np.array(emb),
                )
            finally:
                # The original leaked this connection; close it like
                # store_schema_embeddings does.
                await conn.close()
    elif VECTOR_STORE == "bigquery-vector":
        client = bigquery.Client(project=PROJECT_ID)
        table_ref = f"{PROJECT_ID}.{BQ_OPENDATAQNA_DATASET_NAME}.example_prompt_sql_embeddings"
        client.query_and_wait(f'''CREATE TABLE IF NOT EXISTS `{table_ref}` (
            user_grouping string NOT NULL, example_user_question string NOT NULL, example_generated_sql string NOT NULL,
            embedding ARRAY<FLOAT64>)''')
        cleaned_sql = generated_sql.replace("\r", " ").replace("\n", " ")
        # Use query parameters: the original interpolated raw user text into
        # the SQL, which broke on quotes in the question/SQL and allowed
        # SQL injection.
        delete_config = bigquery.QueryJobConfig(query_parameters=[
            bigquery.ScalarQueryParameter("user_grouping", "STRING", database),
            bigquery.ScalarQueryParameter("question", "STRING", user_question),
        ])
        client.query_and_wait(
            f"DELETE FROM `{table_ref}` WHERE user_grouping = @user_grouping AND example_user_question = @question",
            job_config=delete_config,
        )
        insert_config = bigquery.QueryJobConfig(query_parameters=[
            bigquery.ScalarQueryParameter("user_grouping", "STRING", database),
            bigquery.ScalarQueryParameter("question", "STRING", user_question),
            bigquery.ScalarQueryParameter("sql", "STRING", cleaned_sql),
            bigquery.ArrayQueryParameter("embedding", "FLOAT64", [float(v) for v in emb]),
        ])
        client.query_and_wait(
            f"""INSERT INTO `{table_ref}` (user_grouping, example_user_question, example_generated_sql, embedding)
                VALUES (@user_grouping, @question, @sql, @embedding)""",
            job_config=insert_config,
        )
    return 1
if __name__ == '__main__':
    # Manual smoke test: pull the current embeddings out of the pgvector
    # store and write them straight back through the upsert path.
    from retrieve_embeddings import retrieve_embeddings
    from utilities import PG_SCHEMA, PROJECT_ID, PG_INSTANCE, PG_DATABASE, PG_USER, PG_PASSWORD, PG_REGION

    VECTOR_STORE = "cloudsql-pgvector"
    table_embeddings, column_embeddings = retrieve_embeddings(VECTOR_STORE, PG_SCHEMA)
    asyncio.run(
        store_schema_embeddings(
            table_embeddings,
            column_embeddings,
            PROJECT_ID,
            PG_INSTANCE,
            PG_DATABASE,
            PG_SCHEMA,
            PG_USER,
            PG_PASSWORD,
            PG_REGION,
            VECTOR_STORE=VECTOR_STORE,
        )
    )