in nl2sql_src/nl2sql_query_embeddings.py [0:0]
def search_matching_queries(self, new_query, top_k=3):
    """
    Return the stored question/SQL pairs most similar to ``new_query``.

    The new query is embedded with ``self.embedding_model`` and searched
    against the vector index stored at ``self.INDEX_FILE`` (loaded with
    ``read_index``). If that file cannot be read, the index is rebuilt via
    ``self.recreate_vectordb_index()`` and read again.

    Args:
        new_query: Natural-language question to match.
        top_k: Maximum number of matches to return. Defaults to 3, the
            value that was previously hard-coded.

    Returns:
        A list of at most ``top_k`` dicts with keys ``'question'`` and
        ``'sql'``, in the order returned by ``index.search``. The ``<dq>``
        and ``<sq>`` placeholders in stored SQL are turned back into ``"``
        and ``'``. Returns an empty list when no rows are stored.
    """
    rows = self.extract_data().fetchall()
    if not rows:
        # No stored rows: the DataFrame would have no 'question'/'sql'
        # columns, and there is nothing to match against anyway.
        return []
    df = DataFrame(rows)
    questions = df['question']
    sqls = df['sql']

    nq_emb = self.embedding_model.get_embeddings([new_query])[0].values
    nq_emb_array = np.asarray([nq_emb], dtype=np.float32)

    try:
        logger.info(f"Trying to read the index file : {self.INDEX_FILE}")
        index = read_index(self.INDEX_FILE)
    except Exception as err:
        # Best-effort recovery: rebuild the index, but record why the read
        # failed instead of swallowing the error silently.
        logger.warning(
            f"Could not read index file {self.INDEX_FILE} ({err!r}); "
            "recreating it"
        )
        self.recreate_vectordb_index()
        index = read_index(self.INDEX_FILE)

    _, ids = index.search(nq_emb_array, k=top_k)

    output_json = []
    for row_id in ids[0]:
        if row_id < 0:
            # FAISS-style indexes pad missing results with -1 when they
            # hold fewer than k vectors; such slots have no matching row.
            continue
        # Index ids are row positions, so use positional (.iloc) access.
        sql = sqls.iloc[row_id].replace('<dq>', '"').replace("<sq>", "'")
        output_json.append({'question': questions.iloc[row_id], 'sql': sql})
    return output_json