in env_setup.py [0:0]
def _locate_data_source_list(start_dir, stop_dir, filename='data_source_list.csv'):
    """Find the nearest ``filename`` by searching ``start_dir`` and then its parents.

    Each candidate directory is searched recursively (top-down ``os.walk``);
    the first match is returned immediately, so a file close to ``start_dir``
    always wins over one found further up the tree.

    Args:
        start_dir: Directory where the search begins.
        stop_dir: Directory at which climbing stops. It is never searched
            itself, so nothing is searched when ``start_dir == stop_dir``.
        filename: Name of the file to look for.

    Returns:
        The path of the first match, or ``None`` when nothing was found.
    """
    import os

    current_dir = start_dir
    while current_dir != stop_dir:
        for dirpath, _dirnames, _filenames in os.walk(current_dir):
            candidate = os.path.join(dirpath, filename)
            if os.path.exists(candidate):
                return candidate
        parent_dir = os.path.dirname(current_dir)
        if parent_dir == current_dir:
            # Reached the filesystem root: dirname() no longer changes the
            # path, so continuing would loop forever.
            break
        current_dir = parent_dir
    return None


def _attach_user_grouping(embeddings, grouping_lookup):
    """Add a ``user_grouping`` column to ``embeddings`` and drop ``join_by``.

    Rows that do not yet have a ``join_by`` key get
    ``<source_type>_<table_schema>_<table_name>``. The key is then left-joined
    against ``grouping_lookup`` (columns ``join_by`` and ``user_grouping``).
    Rows without a match fall back to ``<table_schema>-<source_type>``.

    Args:
        embeddings: DataFrame with ``source_type``, ``join_by``,
            ``table_schema`` and ``table_name`` columns.
        grouping_lookup: DataFrame with ``join_by`` and ``user_grouping``.

    Returns:
        A new DataFrame with ``user_grouping`` added and ``join_by`` removed.
    """
    per_table_key = (
        embeddings['source_type'] + "_" + embeddings['table_schema'] + "_" + embeddings['table_name']
    )
    embeddings['join_by'] = embeddings['join_by'].fillna(per_table_key)

    merged = embeddings.merge(grouping_lookup, on='join_by', how='left')
    merged = merged.drop(columns=["join_by"])

    # Replace NaN values in group to default to the schema
    default_grouping = merged['table_schema'] + "-" + merged['source_type']
    merged['user_grouping'] = merged['user_grouping'].fillna(default_grouping)
    return merged


def get_embeddings():
    """Generate table and column embeddings for every source in ``data_source_list.csv``.

    Steps:
        1. Locate ``data_source_list.csv`` by searching the current working
           directory and then each parent directory (recursively), stopping
           before the user's home directory. The nearest match is used.
        2. Read its ``source``, ``user_grouping``, ``schema`` and ``table``
           columns and group the listed tables per (source, schema).
        3. Call ``retrieve_embeddings`` once per (source, schema) pair. An
           empty table list is passed when no table is listed for that pair.
        4. Attach ``user_grouping`` to every embedding row; rows without a
           matching entry default to ``<table_schema>-<source_type>``.

    NOTE(review): ``retrieve_embeddings`` is defined outside this block. It is
    assumed to return two DataFrames (table-level, column-level) that carry
    ``table_schema`` and ``table_name`` columns -- confirm against its
    definition.

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame]:
            - table_schema_embeddings: table-level rows.
            - col_schema_embeddings: column-level rows.

    Raises:
        FileNotFoundError: If ``data_source_list.csv`` cannot be located.
        ValueError: If any row of the file has an empty ``schema`` value.
    """
    print("Generating embeddings from source db schemas")

    import pandas as pd
    import os

    search_start = os.getcwd()
    file_path = _locate_data_source_list(search_start, os.path.expanduser('~'))
    if file_path is None:
        raise FileNotFoundError(
            "data_source_list.csv was not found while searching upward from "
            + search_start
        )
    print("Source Found at Path :: "+file_path)

    # Load the file
    df_src = pd.read_csv(file_path)
    df_src = df_src.loc[:, ["source", "user_grouping", "schema", "table"]]
    df_src = df_src.sort_values(by=["source", "user_grouping", "schema", "table"])

    # If no schema Error Out
    schema_values = df_src['schema']
    if schema_values.isna().any() or (schema_values.astype(str).str.len() == 0).any():
        raise ValueError("Schema column cannot be empty")

    # Group by for all the tables filtered. Missing table names are dropped,
    # so a (source, schema) pair with no listed table ends up with an empty
    # list, which marks a whole-schema entry below.
    df = (
        df_src.groupby(['source', 'schema'])['table']
        .agg(lambda tables: list(tables.dropna().unique()))
        .reset_index()
    )

    print("The Embeddings are extracted for the below combinations")
    print(df)

    # The empty template frames come first so the result always has the
    # expected column order and a 'join_by' column, even when no
    # whole-schema entry exists.
    table_frames = [pd.DataFrame(columns=['source_type', 'join_by', 'table_schema', 'table_name', 'content', 'embedding'])]
    column_frames = [pd.DataFrame(columns=['source_type', 'join_by', 'table_schema', 'table_name', 'column_name', 'content', 'embedding'])]

    for _, row in df.iterrows():
        data_source = row['source']
        schema = row['schema']
        table_list = row['table']

        _t, _c = retrieve_embeddings(data_source, SCHEMA=schema, table_names=table_list)
        _t["source_type"] = data_source
        _c["source_type"] = data_source

        if not table_list:
            # Whole-schema entry: every row shares one key so it matches the
            # CSV row whose 'table' cell is empty.
            schema_key = f"{data_source}_{schema}_{schema}"
            _t["join_by"] = schema_key
            _c["join_by"] = schema_key

        table_frames.append(_t)
        column_frames.append(_c)

    table_schema_embeddings = pd.concat(table_frames, ignore_index=True)
    col_schema_embeddings = pd.concat(column_frames, ignore_index=True)

    # Build the same key on the CSV side: the schema stands in for the table
    # name when the 'table' cell is empty.
    df_src['join_by'] = [
        f"{source}_{schema}_{schema if pd.isna(table) else table}"
        for source, schema, table in zip(df_src['source'], df_src['schema'], df_src['table'])
    ]
    grouping_lookup = df_src[['join_by', 'user_grouping']]

    table_schema_embeddings = _attach_user_grouping(table_schema_embeddings, grouping_lookup)
    col_schema_embeddings = _attach_user_grouping(col_schema_embeddings, grouping_lookup)

    print("Table and Column embeddings are created")
    return table_schema_embeddings, col_schema_embeddings