embeddings/retrieve_embeddings.py (90 lines of code) (raw):
import re
import io
import sys
import pandas as pd
from dbconnectors import pgconnector,bqconnector
from agents import EmbedderAgent, ResponseAgent, DescriptionAgent
from utilities import EMBEDDING_MODEL, DESCRIPTION_MODEL, USE_COLUMN_SAMPLES
embedder = EmbedderAgent(EMBEDDING_MODEL)
# responder = ResponseAgent('gemini-1.0-pro')
descriptor = DescriptionAgent(DESCRIPTION_MODEL)
def get_embedding_chunked(textinput, batch_size):
for i in range(0, len(textinput), batch_size):
request = [x["content"] for x in textinput[i : i + batch_size]]
response = embedder.create(request) # Vertex Textmodel Embedder
# Store the retrieved vector embeddings for each chunk back.
for x, e in zip(textinput[i : i + batch_size], response):
x["embedding"] = e
# Store the generated embeddings in a pandas dataframe.
out_df = pd.DataFrame(textinput)
return out_df
def retrieve_embeddings(SOURCE, SCHEMA="public", table_names = None):
""" Augment all the DB schema blocks to create document for embedding """
if SOURCE == "cloudsql-pg":
table_schema_sql = pgconnector.return_table_schema_sql(SCHEMA,table_names=table_names)
table_desc_df = pgconnector.retrieve_df(table_schema_sql)
column_schema_sql = pgconnector.return_column_schema_sql(SCHEMA,table_names=table_names)
column_name_df = pgconnector.retrieve_df(column_schema_sql)
#GENERATE MISSING DESCRIPTIONS
table_desc_df,column_name_df= descriptor.generate_missing_descriptions(SOURCE,table_desc_df,column_name_df)
#ADD SAMPLES VALUES FOR COLUMNS
column_name_df["sample_values"]=None
if USE_COLUMN_SAMPLES:
column_name_df = pgconnector.get_column_samples(column_name_df)
### TABLE EMBEDDING ###
"""
This SQL returns a df containing the cols table_schema, table_name, table_description, table_columns (with cols in the table)
for the schema specified above, e.g. 'retail'
"""
table_details_chunked = []
for index_aug, row_aug in table_desc_df.iterrows():
cur_table_name = str(row_aug['table_name'])
cur_table_schema = str(row_aug['table_schema'])
curr_col_names = str(row_aug['table_columns'])
curr_tbl_desc = str(row_aug['table_description'])
table_detailed_description=f"""
Table Name: {cur_table_name} |
Schema Name: {cur_table_schema} |
Table Description - {curr_tbl_desc}) |
Columns List: [{curr_col_names}]"""
r = {"table_schema": cur_table_schema,"table_name": cur_table_name,"content": table_detailed_description}
table_details_chunked.append(r)
table_details_embeddings = get_embedding_chunked(table_details_chunked, 10)
### COLUMN EMBEDDING ###
"""
This SQL returns a df containing the cols table_schema, table_name, column_name, data_type, column_description, table_description, primary_key, column_constraints
for the schema specified above, e.g. 'retail'
"""
column_details_chunked = []
for index_aug, row_aug in column_name_df.iterrows():
cur_table_name = str(row_aug['table_name'])
cur_table_owner = str(row_aug['table_schema'])
curr_col_name = str(row_aug['table_schema'])+'.'+str(row_aug['table_name'])+'.'+str(row_aug['column_name'])
curr_col_datatype = str(row_aug['data_type'])
curr_col_description = str(row_aug['column_description'])
curr_col_constraints = str(row_aug['column_constraints'])
curr_column_name = str(row_aug['column_name'])
curr_column_samples = str(row_aug['sample_values'])
column_detailed_description=f"""Schema Name:{cur_table_owner} | Column Name: {curr_col_name} (Data type: {curr_col_datatype}) | Table Name: {cur_table_name} | (column description: {curr_col_description})(constraints: {curr_col_constraints}) | (Sample Values in the Column: {curr_column_samples})"""
r = {"table_schema": cur_table_owner,"table_name": cur_table_name,"column_name":curr_column_name, "content": column_detailed_description}
column_details_chunked.append(r)
column_details_embeddings = get_embedding_chunked(column_details_chunked, 10)
elif SOURCE=='bigquery':
table_schema_sql = bqconnector.return_table_schema_sql(SCHEMA, table_names=table_names)
table_desc_df = bqconnector.retrieve_df(table_schema_sql)
column_schema_sql = bqconnector.return_column_schema_sql(SCHEMA, table_names=table_names)
column_name_df = bqconnector.retrieve_df(column_schema_sql)
#GENERATE MISSING DESCRIPTIONS
table_desc_df,column_name_df= descriptor.generate_missing_descriptions(SOURCE,table_desc_df,column_name_df)
#ADD SAMPLES VALUES FOR COLUMNS
column_name_df["sample_values"]=None
if USE_COLUMN_SAMPLES:
column_name_df = bqconnector.get_column_samples(column_name_df)
#TABLE EMBEDDINGS
table_details_chunked = []
for index_aug, row_aug in table_desc_df.iterrows():
cur_project_name =str(row_aug['project_id'])
cur_table_name = str(row_aug['table_name'])
cur_table_schema = str(row_aug['table_schema'])
curr_col_names = str(row_aug['table_columns'])
curr_tbl_desc = str(row_aug['table_description'])
table_detailed_description=f"""
Full Table Name : {cur_project_name}.{cur_table_schema}.{cur_table_name} |
Table Columns List: [{curr_col_names}] |
Table Description: {curr_tbl_desc} """
r = {"table_schema": cur_table_schema,"table_name": cur_table_name,"content": table_detailed_description}
table_details_chunked.append(r)
table_details_embeddings = get_embedding_chunked(table_details_chunked, 10)
### COLUMN EMBEDDING ###
"""
This SQL returns a df containing the cols table_schema, table_name, column_name, data_type, column_description, table_description, primary_key, column_constraints
for the schema specified above, e.g. 'retail'
"""
column_details_chunked = []
for index_aug, row_aug in column_name_df.iterrows():
cur_project_name =str(row_aug['project_id'])
cur_table_name = str(row_aug['table_name'])
cur_table_owner = str(row_aug['table_schema'])
curr_col_name = str(row_aug['table_schema'])+'.'+str(row_aug['table_name'])+'.'+str(row_aug['column_name'])
curr_col_datatype = str(row_aug['data_type'])
curr_col_description = str(row_aug['column_description'])
curr_col_constraints = str(row_aug['column_constraints'])
curr_column_name = str(row_aug['column_name'])
curr_column_samples = str(row_aug['sample_values'])
column_detailed_description=f"""
Column Name: {curr_col_name}|
Full Table Name : {cur_project_name}.{cur_table_schema}.{cur_table_name} |
Data type: {curr_col_datatype}|
Column description: {curr_col_description}|
Column Constraints: {curr_col_constraints}|
Sample Values in the Column : {curr_column_samples}"""
r = {"table_schema": cur_table_owner,"table_name": cur_table_name,"column_name":curr_column_name, "content": column_detailed_description}
column_details_chunked.append(r)
column_details_embeddings = get_embedding_chunked(column_details_chunked, 10)
return table_details_embeddings, column_details_embeddings
if __name__ == '__main__':
SOURCE = 'cloudsql-pg'
t, c = retrieve_embeddings(SOURCE, SCHEMA="public")