def search_matching_queries()

in nl2sql_src/nl2sql_query_embeddings.py [0:0]


    def search_matching_queries(self, new_query):
        """
        Return up to 3 stored questions (and their SQL) closest to new_query.

        The stored rows come from ``self.extract_data()``; ``new_query`` is
        embedded with ``self.embedding_model`` and looked up in the index
        read from ``self.INDEX_FILE``. If the index file cannot be read, it
        is rebuilt once via ``self.recreate_vectordb_index()`` and read again.

        Args:
            new_query: Query text to embed and search for.

        Returns:
            A list of ``{'question': ..., 'sql': ...}`` dicts, in the order
            the index search returned them. ``<dq>`` and ``<sq>``
            placeholders in the stored SQL are replaced by ``"`` and ``'``.
            The list holds fewer than 3 entries when the search reports
            fewer neighbours.
        """
        cursor = self.extract_data()
        df = DataFrame(cursor.fetchall())

        questions = df['question']
        sql_statements = df['sql']

        # The index search expects a 2-D float32 array: one row per query.
        query_vector = self.embedding_model.get_embeddings([new_query])[0].values
        query_matrix = np.asarray([query_vector], dtype=np.float32)

        logger.info("Trying to read the index file : %s", self.INDEX_FILE)
        try:
            index = read_index(self.INDEX_FILE)
        except Exception:
            # Best-effort recovery: rebuild the index and retry once, but
            # keep a trace of why the first read failed.
            logger.warning(
                "Could not read index file %s; recreating the index",
                self.INDEX_FILE,
                exc_info=True,
            )
            self.recreate_vectordb_index()
            index = read_index(self.INDEX_FILE)

        _scores, neighbour_ids = index.search(query_matrix, k=3)

        output_json = []
        for neighbour_id in neighbour_ids[0]:
            row = int(neighbour_id)
            # NOTE(review): read_index looks like faiss.read_index; FAISS
            # fills the id array with -1 when the index holds fewer than k
            # vectors. Skip those instead of failing on the lookup - confirm.
            if row < 0:
                continue

            # Assumes index ids are row positions in the extracted data,
            # as the original label lookup on the default index did.
            sql_text = sql_statements.iloc[row]
            sql_text = sql_text.replace('<dq>', '"').replace("<sq>", "'")
            output_json.append({
                'question': questions.iloc[row],
                'sql': sql_text,
            })

        return output_json