nl2sql_library/nl2sql_query_embeddings.py
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
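"""
Embedding-store utilities for NL2SQL few-shot retrieval.

Two backends store question/SQL pairs with their text embeddings and
retrieve the stored questions closest to a new query:

* Nl2Sql_embed -- a local JSON file plus a FAISS index on disk.
* PgSqlEmb -- a Cloud SQL for PostgreSQL table plus a FAISS index on disk.
"""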

import json
import os

import faiss
import numpy as np
import pg8000
import sqlalchemy
from faiss import write_index, read_index
from google.cloud.sql.connector import Connector, IPTypes
from loguru import logger
from pandas import DataFrame
from sklearn.metrics.pairwise import cosine_similarity
from vertexai.preview.language_models import TextEmbeddingModel

# Connect over a private IP when PRIVATE_IP is set in the environment;
# otherwise use the instance's public IP.
ip_type = IPTypes.PRIVATE if os.environ.get("PRIVATE_IP") else IPTypes.PUBLIC


class Nl2Sql_embed:
    """
    Local embeddings and local vector DB class.
    """

    def __init__(self):
        self.EMBEDDING_FILE = "utils/embeddings.json"
        self.INDEX_FILE = "utils/saved_index_localdata"
        self.embedding_model = \
            TextEmbeddingModel.from_pretrained("textembedding-gecko@003")

    def generate_embedding(self, query, sql='blank sql'):
        """
        Generate text embeddings for a question and (optionally) its SQL.
        """
        q_embeddings = self.embedding_model.get_embeddings([query])[0].values
        sql_embeddings = self.embedding_model.get_embeddings([sql])[0].values
        return q_embeddings, sql_embeddings
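
    # Note: textembedding-gecko models return 768-dimensional vectors, so
    # q_embeddings and sql_embeddings are each a list of 768 floats.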

    def insert_data(self, question, sql):
        """
        Append a question/SQL pair and its embeddings to the embeddings file.
        """
        logger.info(f"Inserting data. Question : {question}, SQL : {sql}")
        try:
            with open(self.EMBEDDING_FILE, "r") as f:
                data = json.load(f)
        except FileNotFoundError:
            data = []
        q_emb, sql_emb = self.generate_embedding(question, sql)
        data.append({
            "question": question,
            "sql": sql,
            "question_embedding": q_emb,
            "sql_embedding": sql_emb
        })
        with open(self.EMBEDDING_FILE, "w") as f:
            json.dump(data, f)
        self.update_vectordb_index(query=question)

    def load_embeddings(self):
        """
        Read the embeddings JSON file into memory.
        """
        with open(self.EMBEDDING_FILE, "r") as f:
            data = json.load(f)
        return data

    def distance(self, embedding1, embedding2):
        """Calculate the negative cosine similarity between two embeddings."""
        return -cosine_similarity([embedding1], [embedding2])[0][0]

    def find_closest_questions(self, new_question, data, n=3):
        """
        Return the n most similar stored questions and their SQLs.
        """
        new_embedding, _ = self.generate_embedding(new_question)
        # Smaller distance means higher cosine similarity, so take the
        # first n indices of the ascending sort.
        distances = [
            self.distance(
                new_embedding,
                item["question_embedding"]) for item in data]
        closest_indices = np.argsort(distances)[:n]
        return [(data[i]['question'], data[i]['sql']) for i in closest_indices]
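
    # Example (sketch): brute-force retrieval without the FAISS index.
    # Assumes Vertex AI is initialised and utils/embeddings.json exists.
    #
    #   embed = Nl2Sql_embed()
    #   embed.insert_data("How many users signed up?",
    #                     "SELECT COUNT(*) FROM users")
    #   data = embed.load_embeddings()
    #   matches = embed.find_closest_questions("Total signups?", data, n=3)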

    def create_vectordb_index(self):
        """
        Recreate the vector DB index file from the embeddings file.
        """
        embeddings_data = self.load_embeddings()
        query_embeddings = [
            item['question_embedding'] for item in embeddings_data
        ]
        embeddings_data_array = np.asarray(query_embeddings, dtype=np.float32)
        # Build a flat inner-product index sized to the embedding dimension.
        index = faiss.IndexFlatIP(len(query_embeddings[0]))
        index.add(embeddings_data_array)
        write_index(index, self.INDEX_FILE)

    def update_vectordb_index(self, query):
        """
        Add a single query embedding to the vector DB index file.
        """
        emb = self.embedding_model.get_embeddings([query])[0].values
        new_array = [emb]
        embeddings_data_array = np.asarray(new_array, dtype=np.float32)
        # Read the stored index; create a fresh one if it does not exist yet.
        try:
            index = read_index(self.INDEX_FILE)
        except Exception:
            index = faiss.IndexFlatIP(len(new_array[0]))
        index.add(embeddings_data_array)
        write_index(index, self.INDEX_FILE)

    def search_matching_queries(self, new_query):
        """
        Return the 3 most similar stored questions and their SQLs.
        """
        embeddings_data = self.load_embeddings()
        query_array_updated = [[item['question'],
                                item['sql']] for item in embeddings_data]
        nq_emb = self.embedding_model.get_embeddings([new_query])[0].values
        nq_emb_array = np.asarray([nq_emb], dtype=np.float32)
        index = read_index(self.INDEX_FILE)
        scores, ids = index.search(nq_emb_array, k=3)
        output_json = []
        for i in range(len(scores[0])):
            res = {}
            res['question'] = query_array_updated[ids[0][i]][0]
            res['sql'] = query_array_updated[ids[0][i]][1]
            output_json.append(res)
        return output_json
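
# Example (sketch): FAISS-backed retrieval with the local backend. Assumes
# Vertex AI is initialised and the index file has been built beforehand
# via create_vectordb_index() or insert_data().
#
#   embed = Nl2Sql_embed()
#   embed.create_vectordb_index()
#   matches = embed.search_matching_queries("Total signups last month?")
#   # -> [{'question': ..., 'sql': ...}, ...] for the top-3 matches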


class PgSqlEmb:
    """
    PostgreSQL DB interface class.
    """

    def __init__(self,
                 proj_id,
                 loc,
                 pg_inst,
                 pg_db,
                 pg_uname,
                 pg_pwd,
                 pg_table='documents',
                 index_file='saved_index_pgdata'):
        self.PGPROJ = proj_id
        self.PGLOCATION = loc
        self.PGINSTANCE = pg_inst
        self.CONNSTRING = f"{self.PGPROJ}:{self.PGLOCATION}:{self.PGINSTANCE}"
        self.USER = pg_uname
        self.PWD = pg_pwd
        self.PGDB = pg_db
        self.PGTABLE = pg_table
        self.INDEX_FILE = \
            f"../../nl2sql-generic/nl2sql_src/cache_metadata/{index_file}"
        self.embedding_model = \
            TextEmbeddingModel.from_pretrained("textembedding-gecko@003")
        self.pool = self.getpool()

    def getconn(self) -> pg8000.dbapi.Connection:
        """
        Get a DB connection via the Cloud SQL Python Connector.
        """
        connector = Connector()
        conn: pg8000.dbapi.Connection = connector.connect(
            self.CONNSTRING,
            "pg8000",
            user=self.USER,
            password=self.PWD,
            db=self.PGDB,
            ip_type=ip_type,
        )
        return conn

    def getpool(self):
        """
        Return a SQLAlchemy connection pool for the Cloud SQL instance.
        """
        # The engine calls getconn() lazily whenever the pool needs a
        # new connection; nothing is opened until first use.
        pool = sqlalchemy.create_engine(
            "postgresql+pg8000://",
            creator=self.getconn,
        )
        return pool

    def create_table(self):
        """
        Create the question/embedding table in the PostgreSQL DB.
        """
        sql_create = f"""CREATE TABLE IF NOT EXISTS {self.PGTABLE} (
                        q_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY,
                        question TEXT,
                        sql TEXT,
                        query_embedding TEXT
                        );"""
        # pool.begin() commits the transaction on successful exit.
        with self.pool.begin() as conn:
            conn.execute(sqlalchemy.text(sql_create))

    def empty_table(self, remove_index=True):
        """
        Delete all rows in the PostgreSQL table and, optionally, the
        saved index file.
        """
        sql_clear = f'DELETE FROM {self.PGTABLE}'
        with self.pool.begin() as conn:
            conn.execute(sqlalchemy.text(sql_clear))
        if remove_index:
            try:
                os.remove(self.INDEX_FILE)
            except OSError:
                pass

    def insert_row(self, query, sql):
        """
        Insert a question, its SQL and its embedding into the PostgreSQL DB.
        """
        # Encode quotes so the SQL text survives storage; they are decoded
        # again in search_matching_queries().
        sql = sql.replace("'", "<sq>")
        sql = sql.replace('"', '<dq>')
        emb = self.embedding_model.get_embeddings([query])[0].values
        # Bound parameters avoid SQL-injection and quoting issues in the
        # question and embedding values.
        sql_ins = sqlalchemy.text(
            f"INSERT INTO {self.PGTABLE} (question, sql, query_embedding) "
            "VALUES (:question, :sql, :embedding)"
        )
        with self.pool.begin() as conn:
            conn.execute(sql_ins, {"question": query, "sql": sql,
                                   "embedding": str(emb)})
        self.update_vectordb_index(query)
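
    # Example (sketch): quotes inside stored SQL are round-tripped through
    # the <sq>/<dq> markers:
    #   insert_row(q, "SELECT * FROM t WHERE name = 'Bob'")
    #   stores:   SELECT * FROM t WHERE name = <sq>Bob<sq>
    #   search_matching_queries() restores the original quotes on read.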

    def extract_data(self):
        """
        Return all rows from the PostgreSQL table.
        """
        sql_data = f'SELECT * FROM {self.PGTABLE}'
        with self.pool.connect() as conn:
            # Fetch while the connection is still open so the rows can be
            # used after it is returned to the pool.
            rows = conn.execute(sqlalchemy.text(sql_data)).fetchall()
        return rows

    def extract_pg_embeddings(self):
        """
        Extract questions, SQLs and embeddings from the PG database.
        """
        rows = self.extract_data()
        df = DataFrame(rows)
        q_embed = df['query_embedding']
        # Embeddings are stored as strings such as "[0.12, -0.03, ...]";
        # strip the brackets and commas to recover the floats.
        query_embeddings = [item.split(' ') for item in q_embed]
        new_array = []
        for elem in query_embeddings:
            new_row = []
            first_elem = elem[0].split('[')[1].split(',')[0]
            new_row.append(float(first_elem))
            for i in range(1, len(elem) - 1):
                t_elem = elem[i].split(',')[0]
                new_row.append(float(t_elem))
            last_elem = elem[len(elem) - 1].split(']')[0]
            new_row.append(float(last_elem))
            new_array.append(new_row)
        return df['question'], df['sql'], new_array

    def recreate_vectordb_index(self):
        """
        Regenerate the vector DB index file from PG table data.
        """
        # Reuse the string-to-float parsing in extract_pg_embeddings()
        # instead of duplicating it here.
        _, _, new_array = self.extract_pg_embeddings()
        embeddings_data_array = np.asarray(new_array, dtype=np.float32)
        index = faiss.IndexFlatIP(len(new_array[0]))
        index.add(embeddings_data_array)
        write_index(index, self.INDEX_FILE)

    def update_vectordb_index(self, query):
        """
        Add the embedding of a newly inserted query to the vector DB index.
        """
        emb = self.embedding_model.get_embeddings([query])[0].values
        new_array = [emb]
        embeddings_data_array = np.asarray(new_array, dtype=np.float32)
        # Read the stored index; create a fresh one if it does not exist yet.
        try:
            index = read_index(self.INDEX_FILE)
        except Exception:
            index = faiss.IndexFlatIP(len(new_array[0]))
        index.add(embeddings_data_array)
        write_index(index, self.INDEX_FILE)

    def search_matching_queries(self, new_query):
        """
        Return the 3 most similar stored questions and their SQLs.
        """
        rows = self.extract_data()
        df = DataFrame(rows)
        queries_array = df['question']
        sql_array = df['sql']
        nq_emb = self.embedding_model.get_embeddings([new_query])[0].values
        nq_emb_array = np.asarray([nq_emb], dtype=np.float32)
        try:
            logger.info(f"Trying to read the index file : {self.INDEX_FILE}")
            index = read_index(self.INDEX_FILE)
        except Exception:
            self.recreate_vectordb_index()
            index = read_index(self.INDEX_FILE)
        scores, ids = index.search(nq_emb_array, k=3)
        output_json = []
        for i in range(len(scores[0])):
            res = {}
            res['question'] = queries_array[ids[0][i]]
            tmp_sql = sql_array[ids[0][i]]
            # Restore the quotes that insert_row() encoded.
            tmp_sql = tmp_sql.replace('<dq>', '"')
            tmp_sql = tmp_sql.replace("<sq>", "'")
            res['sql'] = tmp_sql
            output_json.append(res)
        return output_json
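

# Example usage (sketch): the connection values below are placeholders and
# assume a reachable Cloud SQL for PostgreSQL instance plus an initialised
# Vertex AI environment.
#
#   pge = PgSqlEmb(proj_id="my-project", loc="us-central1",
#                  pg_inst="my-instance", pg_db="nl2sql",
#                  pg_uname="postgres", pg_pwd="...")
#   pge.create_table()
#   pge.insert_row("How many users signed up?",
#                  "SELECT COUNT(*) FROM users")
#   matches = pge.search_matching_queries("Total signups?")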