# route.py

# -*- coding: utf-8 -*-
import argparse
import copy
import csv
import json
import re
import sqlite3
import traceback
import os

from vllm import LLM, SamplingParams
from func_timeout import func_set_timeout
import func_timeout
import tqdm

prompt_temp = """Given the following database schema and question, your task is to write a valid SQL query whose execution results can accurately answer the question.

/* Database schema */
{ds}

/* Sample rows of each table */
{sr}

/* Question */
{qs}{hint}

Answer the question by a SQL query only with no explanation:
"""

prompt_sl_temp_sft = """Given the following database schema and question, your task is to extract the tables and columns relevant to solving the question.

/* Database schema */
{ds}

/* Sample rows of each table */
{sr}

/* Question */
{qs}{hint}

Output the tables and columns only with no explanation:
"""

prompt_sl_temp = """Given the following database schema and question, your task is to extract the tables and columns relevant to solving the question.

/* Examples */
Example 1:
Database schema: CREATE TABLE department (Department_ID NUMBER, Name TEXT, Creation TEXT, Ranking NUMBER, Budget_in_Billions NUMBER, Num_Employees NUMBER, PRIMARY KEY(Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees)); CREATE TABLE head (head_ID NUMBER, name TEXT, born_state TEXT, age NUMBER, PRIMARY KEY(head_ID, name, born_state, age)); CREATE TABLE management (department_ID NUMBER, head_ID NUMBER, temporary_acting TEXT, PRIMARY KEY(department_ID, head_ID, temporary_acting), FOREIGN KEY (head_ID) REFERENCES head(head_ID), FOREIGN KEY (department_ID) REFERENCES department(Department_ID));
Question: What are the names of the heads who are born outside the California state?
Output: {{"head": ["name", "born_state"]}}

Example 2:
Database schema: CREATE TABLE department (Department_ID NUMBER, Name TEXT, Creation TEXT, Ranking NUMBER, Budget_in_Billions NUMBER, Num_Employees NUMBER, PRIMARY KEY(Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees)); CREATE TABLE head (head_ID NUMBER, name TEXT, born_state TEXT, age NUMBER, PRIMARY KEY(head_ID, name, born_state, age)); CREATE TABLE management (department_ID NUMBER, head_ID NUMBER, temporary_acting TEXT, PRIMARY KEY(department_ID, head_ID, temporary_acting), FOREIGN KEY (head_ID) REFERENCES head(head_ID), FOREIGN KEY (department_ID) REFERENCES department(Department_ID));
Question: How many departments are led by heads who are not mentioned?
Output: {{"department": ["Department_ID"], "management": ["department_ID"]}}

Example 3:
Database schema: CREATE TABLE department (Department_ID NUMBER, Name TEXT, Creation TEXT, Ranking NUMBER, Budget_in_Billions NUMBER, Num_Employees NUMBER, PRIMARY KEY(Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees)); CREATE TABLE head (head_ID NUMBER, name TEXT, born_state TEXT, age NUMBER, PRIMARY KEY(head_ID, name, born_state, age)); CREATE TABLE management (department_ID NUMBER, head_ID NUMBER, temporary_acting TEXT, PRIMARY KEY(department_ID, head_ID, temporary_acting), FOREIGN KEY (head_ID) REFERENCES head(head_ID), FOREIGN KEY (department_ID) REFERENCES department(Department_ID));
Question: How many heads of the departments are older than 56?
Output: {{"head": ["age"]}}

Now, let’s get started!

/* Database schema */
{ds}

/* Sample rows of each table */
{sr}

/* Question */
{qs}{hint}

Output the tables and columns only with no explanation:
"""

prompt_nc_temp_sft = """Your task is to determine whether the execution results of a SQL query can answer the given question according to the following database schema. If the execution results cannot correctly answer the question, please give me the correct SQL query.

/* Database schema */
{ds}

/* Sample rows of each table */
{sr}

/* Question */
{qs}{hint}

/* SQL query */
{sql}{ex_hint}

Output:
"""

prompt_nc_temp = """Your task is to determine whether the execution results of a SQL query can answer the given question according to the following database schema. If the execution results cannot correctly answer the question, please give me the correct SQL query.

/* Examples */
Example 1:
Question: Average of the last receipt cost of the products whose average lead time is 60 days.
SQL query: SELECT SUM(LastReceiptCost) / COUNT(ProductID) FROM ProductVendor;
Output: The execution results of the SQL query cannot correctly answer the question. The correct SQL query should be:
```sql
SELECT SUM(LastReceiptCost) / COUNT(ProductID) FROM ProductVendor WHERE AverageLeadTime = 60;
```

Example 2:
Question: Calculate the average price of products shipped to the UK.
SQL query: SELECT AVG(UnitPrice) AS avg FROM Invoices WHERE Country = 'UK';
Output: The execution results of the SQL query can correctly answer the question.

Example 3:
Question: What is the total cost for all the orders placed on 5/29/2013?
SQL query: SELECT SUM(TotalDue) FROM PurchaseOrderHeader WHERE OrderDate LIKE '2013-05-29%';
Output: The SQL query can correctly answer the question.

Now, let’s get started!

/* Database schema */
{ds}

/* Sample rows of each table */
{sr}

/* Question */
{qs}{hint}

/* SQL query */
{sql}{ex_hint}

Output:
"""

prompt_cw_temp_sft = """Given the following database schema and question, your task is to rewrite an incomplete SQL query into a complete SQL query whose execution results can correctly answer the question.

/* Database schema */
{ds}

/* Sample rows of each table */
{sr}

/* Question */
{qs}{hint}

/* The incomplete SQL query */
```sql
{sql}
```

Output:
"""

prompt_cw_temp = """Given the following database schema and question, your task is to rewrite an incomplete SQL query into a complete SQL query whose execution results can correctly answer the question.

/* Examples */
Example 1:
Question: How many heads of the departments are older than 56 ?
The incomplete SQL query: ```sql\nSELECT count(*);\n```
Output: ```sql\nSELECT count(*) FROM head WHERE age > 56;\n```

Example 2:
Question: What are the distinct creation years of the departments managed by a secretary born in state 'Alabama'?
The incomplete SQL query: ```sql\nSELECT DISTINCT T1.creation FROM department AS T1;\n```
Output: ```sql\nSELECT DISTINCT T1.creation FROM department AS T1 JOIN management AS T2 ON T1.department_id = T2.department_id JOIN head AS T3 ON T2.head_id = T3.head_id WHERE T3.born_state = 'Alabama';\n```

Example 3:
Question: Show the name and number of employees for the departments managed by heads whose temporary acting value is 'Yes'?
The incomplete SQL query: ```sql\nSELECT T1.name, T1.num_employees FROM department AS T1 JOIN management AS T2;\n```
Output: ```sql\nSELECT T1.name, T1.num_employees FROM department AS T1 JOIN management AS T2 ON T1.department_id = T2.department_id WHERE T2.temporary_acting = 'Yes';\n```

Now, let’s get started!

/* Database schema */
{ds}

/* Sample rows of each table */
{sr}

/* Question */
{qs}{hint}

/* The incomplete SQL query */
```sql
{sql}
```

Output:
"""
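
# Hedged sketch (illustrative only): how the templates above are instantiated.
# The schema, sample rows, and question below are made-up placeholders, not
# values produced by this pipeline.
#
#   prompt = prompt_temp.format(
#       ds="CREATE TABLE head (head_ID NUMBER, name TEXT, age NUMBER, PRIMARY KEY(head_ID));",
#       sr="head: [(1, 'Tiger', 56)]",
#       qs="How many heads are older than 56?",
#       hint="",  # BIRD questions pass "\n\n/* Question hint */\n<evidence>" here
#   )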

def read_json_file(file_path):
    try:
        with open(file_path, 'r', encoding='utf-8-sig') as file:
            data = json.load(file)
        return data
    except Exception as e:
        print("=" * 10, e)
        return None


class LLM_Model(object):
    def __init__(self, model=''):
        self.model = model
        model = model.lower().replace('_', '').replace('-', '')
        if 'qwen2' in model:
            self.tag = 'qwen2'
        elif 'llama3' in model:
            self.tag = 'llama3'
        elif 'llama2' in model:
            self.tag = 'llama2'
        elif 'deepseek' in model:
            self.tag = 'deepseek'
        elif 'mistral' in model:
            self.tag = 'mistral'
        elif 'codellama' in model:
            self.tag = 'codellama'
        else:
            raise TypeError(f"Unexpected model: {model}.")
        # `args` is the module-level namespace parsed in __main__
        self.llm = LLM(model=self.model,
                       seed=123,
                       gpu_memory_utilization=0.9,
                       tensor_parallel_size=args.gpus,
                       trust_remote_code=True,
                       )
        self.tokenizer = self.llm.get_tokenizer()

    def generate_response(self, prompts, max_tokens=1024, temperature=0.01, top_p=0.5):
        sampling_params = SamplingParams(temperature=temperature,
                                         top_p=top_p,
                                         max_tokens=max_tokens,
                                         skip_special_tokens=True,
                                         stop=self.tokenizer.eos_token)
        if self.tag in ['mistral']:
            # Mistral chat templates do not accept a system message
            messages_list = [[{"role": "user", "content": p}] for p in prompts]
        else:
            messages_list = [[{"role": "system", "content": "You are a helpful SQLite assistant."},
                              {"role": "user", "content": p}] for p in prompts]
        messages_list = self.tokenizer.apply_chat_template(messages_list, add_generation_prompt=True, tokenize=False)
        outputs = self.llm.generate(messages_list, sampling_params)
        return [output.outputs[0].text for output in outputs]


class LLM_Online(object):
    def __init__(self, model="qwen72b", device=[0]):
        pass

    def generate_response(self, prompts):
        rs = []
        for prompt in tqdm.tqdm(prompts):
            res = None  # plug in your online LLM call here
            rs.append(res)
        return rs


def parse_dataset(data_path, mode='dev', dataset='bird'):
    # redirect path
    data_tuples_path = ''
    if dataset == 'bird':
        data_tuples_path = os.path.join(data_path, dataset, mode, f'{mode}.json')
    elif 'spider_DK' == dataset:
        data_tuples_path = os.path.join(data_path, 'spider', 'Spider_DK.json')
    elif 'spider_real' == dataset:
        data_tuples_path = os.path.join(data_path, 'spider', 'spider-realistic.json')
    elif 'spider' in dataset:
        if mode == 'test':
            data_tuples_path = os.path.join(data_path, 'spider', 'test_data/dev.json')
        else:
            data_tuples_path = os.path.join(data_path, 'spider', f'{mode}.json')
    else:
        raise TypeError(f"Unexpected dataset: {dataset}.")
    data_tuples = read_json_file(data_tuples_path)
    return data_tuples


def convert_fk_index(data):
    fk_holder = []
    table_names_original = [i.lower() for i in data['table_names_original']]
    # some schemas mix cases, so lower-case both sides before matching
    column_names_original = [(i[0], i[1].lower()) for i in data['column_names_original']]
    for fk in data["foreign_keys"]:
        tn, col, ref_tn, ref_col = fk[0][0], fk[0][1], fk[1][0], fk[1][1]
        if type(tn) is str:
            tn = tn.lower()
        if type(col) is str:
            col = col.lower()
        if type(ref_tn) is str:
            ref_tn = ref_tn.lower()
        if type(ref_col) is str:
            ref_col = ref_col.lower()
        ref_cid, cid = None, None
        try:
            tid = table_names_original.index(tn)
            ref_tid = table_names_original.index(ref_tn)
            for i, (tab_id, col_org) in enumerate(column_names_original):
                if tab_id == ref_tid and ref_col == col_org:
                    ref_cid = i
                elif tid == tab_id and col == col_org:
                    cid = i
            if ref_cid and cid:
                fk_holder.append([cid, ref_cid])
        except Exception:
            traceback.print_exc()
            print("table_names_original: ", table_names_original)
            print("finding tab name: ", tn, ref_tn)
            print(data)
            # sys.exit()
    return fk_holder
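
# Hedged usage sketch (paths are illustrative and depend on your local layout):
#
#   data_tuples = parse_dataset('./dataset', mode='dev', dataset='bird')
#   # -> list of dicts with at least 'question', 'db_id', and 'SQL'/'query' keys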

def dump_db_json_schema(db, f):
    '''read table and column info'''
    try:
        conn = sqlite3.connect(db)
    except Exception:
        print(db)
        exit()
    conn.execute('pragma foreign_keys=ON')
    cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table';")
    data = {'db_id': f,
            'table_names_original': [],
            'table_names': [],
            'column_names_original': [(-1, '*')],
            'column_names': [(-1, '*')],
            'column_types': ['text'],
            'primary_keys': [],
            'foreign_keys': []}
    fk_holder = []
    for i, item in enumerate(cursor.fetchall()):
        table_name = item[0]
        data['table_names_original'].append(table_name)
        data['table_names'].append(table_name.lower().replace("_", ' '))
        fks = conn.execute("PRAGMA foreign_key_list('{}') ".format(table_name)).fetchall()
        # print("db:{} table:{} fks:{}".format(f, table_name, fks))
        fk_holder.extend([[(table_name, fk[3]), (fk[2], fk[4])] for fk in fks])
        cur = conn.execute("PRAGMA table_info('{}') ".format(table_name))
        for j, col in enumerate(cur.fetchall()):
            data['column_names_original'].append((i, col[1]))
            data['column_names'].append((i, col[1].lower().replace("_", " ")))
            # varchar, '' -> text; int, numeric, decimal -> number; date/time -> time
            col_type = col[2].lower()
            if 'char' in col_type or col_type == '' or 'text' in col_type or 'var' in col_type:
                data['column_types'].append('text')
            elif 'int' in col_type or 'numeric' in col_type or 'decimal' in col_type or 'number' in col_type \
                    or 'id' in col_type or 'real' in col_type or 'double' in col_type or 'float' in col_type:
                data['column_types'].append('number')
            elif 'date' in col_type or 'time' in col_type or 'year' in col_type:
                data['column_types'].append('time')
            elif 'boolean' in col_type:
                data['column_types'].append('boolean')
            else:
                data['column_types'].append('others')
            if col[5] == 1:
                data['primary_keys'].append(len(data['column_names']) - 1)
    data["foreign_keys"] = fk_holder
    data['foreign_keys'] = convert_fk_index(data)
    return data


def get_schema_dict(db, kk=3):
    """
    Get database's schema, which is a dict with table name as key and list of column names as value
    :param db: database path
    :return: schema dict
    """
    data = dump_db_json_schema(db, db.split('/')[-1])
    tables = data['table_names_original']
    column_types = data['column_types']
    primary_keys = data['primary_keys']
    foreign_keys = data['foreign_keys']
    column_names = data['column_names_original']
    schema_dict = {
        'tables': {},
        'foreign_keys': []
    }
    for i, table in enumerate(tables):
        t = {}
        for j, c in enumerate(column_names):
            if c[0] == i:
                # NOTE: the flag is True for every column, not just j in
                # primary_keys, so all columns are rendered inside
                # PRIMARY KEY(...) in the schema string (see the prompt
                # examples above, which reflect this behavior)
                t[c[1]] = [column_types[j].upper(), True]
        schema_dict['tables'][table] = t
    for foreign_key in foreign_keys:
        t1 = tables[column_names[foreign_key[0]][0]]
        c1 = column_names[foreign_key[0]][1]
        t2 = tables[column_names[foreign_key[1]][0]]
        c2 = column_names[foreign_key[1]][1]
        schema_dict['foreign_keys'].append([t1, c1, t2, c2])

    # get examples
    conn = sqlite3.connect(db)
    cursor = conn.cursor()
    for table in schema_dict['tables'].keys():
        try:
            select_query = f'SELECT * FROM `{table}` LIMIT {kk}'
            cursor.execute(select_query)
            rows = cursor.fetchall()
            cursor.execute(f"PRAGMA table_info(`{table}`);")
            columns = [column[1] for column in cursor.fetchall()]
            for i, c in enumerate(columns):
                # truncate long strings so sample rows stay short
                cls_values = [f"{row[i][0:100]}..." if type(row[i]) is str and len(row[i]) > 100 else row[i]
                              for row in rows]
                schema_dict['tables'][table][c].append(cls_values)
        except Exception as e:
            print(e)
    return schema_dict
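
# Shape of the dict returned above (values are illustrative):
#
#   {
#       'tables': {
#           'head': {
#               'head_ID': ['NUMBER', True, [1, 2, 3]],  # [type, pk_flag, sample values]
#               'name':    ['TEXT',   True, ['Tiger', 'Duo', 'Kim']],
#           },
#       },
#       'foreign_keys': [['management', 'head_ID', 'head', 'head_ID']],
#   }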

def get_example_str(schema_dict, k=1):
    tables = list(schema_dict['tables'].keys())
    examples = {}
    for table in tables:
        table_dict = schema_dict['tables'][table]
        example = []
        for cls in table_dict.keys():
            example.append(table_dict[cls][2])
        example_str = []
        for i, v in enumerate(example[0]):
            example_str.append(tuple([e[i] for e in example]))
            if (i + 1) == k:
                break
        examples[table] = example_str
    e_s = ''
    for key in examples.keys():
        e_s += f"{key}: " + str(examples[key]) + '\n'
    return e_s[:-1]


def get_schema_str_and_examples(schema_dict):
    schema_str = ""
    tables = list(schema_dict['tables'].keys())
    examples = {}
    for table in tables:
        if ' ' in table:
            table_str = f'CREATE TABLE "{table}" ('
        else:
            table_str = f"CREATE TABLE {table} ("
        table_dict = schema_dict['tables'][table]
        pk_str = ''
        example = []
        for cls in table_dict.keys():
            try:
                cls_ = f'"{cls}"' if ' ' in cls else cls
                table_str += f"{cls_} {table_dict[cls][0]}, "
                if table_dict[cls][1]:
                    pk_str += cls_ + ', '
                example.append(table_dict[cls][2])
            except Exception as e:
                print(e)
        example_str = []
        try:
            for i, v in enumerate(example[0]):
                example_str.append(tuple([e[i] for e in example]))
        except Exception as e:
            print(e)
        examples[table] = example_str
        if pk_str != '':
            table_str += f"PRIMARY KEY({pk_str[:-2]}), "
        fk_str = ''
        for fk in schema_dict['foreign_keys']:
            if fk[0] == table and fk[2] in tables:
                if fk[3] in schema_dict['tables'][fk[2]].keys():
                    fk = [f'"{f}"' if ' ' in f else f for f in fk]
                    fk_str += f'FOREIGN KEY ({fk[1]}) REFERENCES {fk[2]}({fk[3]}), '
        if fk_str != '':
            table_str += fk_str
        schema_str += table_str[:-2] + '); '
    schema_str = schema_str[:-1]
    e_s = ''
    for key in examples.keys():
        e_s += f"{key}: " + str(examples[key]) + '\n'
    return schema_str, e_s[:-1]


# parse SQL
def parse_sql_from_string(input_string):
    input_string = input_string.replace('\n', ' ').replace('\t', '')
    rs = ''
    if '```sql' in input_string:
        try:
            sql_pattern = r'```sql(.*?)```'
            all_sqls = []
            for match in re.finditer(sql_pattern, input_string, re.DOTALL):
                all_sqls.append(match.group(1).strip())
            if all_sqls:
                rs = all_sqls[-1]
                if 'SELECT' not in rs and len(all_sqls) > 1:
                    rs = all_sqls[-2]
        except Exception:
            pass
    if 'select' in input_string.lower() and rs == '':
        rs = input_string[input_string.find('SELECT'):]
    if ';' in rs:  # truncate at the first semicolon
        rs = rs[:rs.find(';') + 1]
    if rs == '':
        rs = 'SELECT xx FROM xx'  # placeholder when no SQL can be recovered
    return replace_multiple_spaces(rs).replace('```', '')


def replace_multiple_spaces(text):
    return re.sub(r'\s{2,}', ' ', text)


def filter_dict_by_sql(schema_dict, sql):
    schema_dict_ = copy.deepcopy(schema_dict)
    # tables: keep only tables referenced via FROM/JOIN
    keys = list(schema_dict_['tables'].keys())
    keys.sort(key=lambda x: -len(x))
    for table in keys:
        if f'from {table.lower()}' not in sql.lower() and f'join {table.lower()}' not in sql.lower():
            schema_dict_['tables'].pop(table, None)
    # columns: keep only columns whose names appear in the SQL
    keys = list(schema_dict_['tables'].keys())
    keys.sort(key=lambda x: -len(x))
    for table in keys:
        cls_keys = list(schema_dict_['tables'][table].keys())
        cls_keys.sort(key=lambda x: -len(x))
        table_dict = copy.deepcopy(schema_dict_['tables'][table])
        for cls in cls_keys:
            if cls.lower() not in sql.lower():
                schema_dict_['tables'][table].pop(cls, None)
        if len(schema_dict_['tables'][table].keys()) == 0:
            # fall back to the primary-key columns (e.g. for COUNT(*))
            for cls in table_dict.keys():
                if table_dict[cls][1] == True:
                    schema_dict_['tables'][table][cls] = table_dict[cls]
        if len(schema_dict_['tables'][table].keys()) == 0:
            # still empty: keep the first two columns (e.g. for COUNT(*))
            cls_list = list(table_dict.keys())
            schema_dict_['tables'][table][cls_list[0]] = table_dict[cls_list[0]]
            schema_dict_['tables'][table][cls_list[1]] = table_dict[cls_list[1]]
    return schema_dict_
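
# Hedged sketch of the helpers above (the response text and SQL are made up):
#
#   >>> parse_sql_from_string("Sure! ```sql\nSELECT count(*)\nFROM head;\n```")
#   'SELECT count(*) FROM head;'
#
#   pruned = filter_dict_by_sql(schema_dict, "SELECT name FROM head WHERE age > 56")
#   # pruned['tables'] keeps only 'head', and only the columns whose names
#   # appear in the SQL; primary keys are restored if nothing matched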

def filter_dict_by_sl(schema_dict, sql):
    # like filter_dict_by_sql, but matches bare table names, so it also works
    # on schema-linking output that is not valid SQL
    schema_dict_ = copy.deepcopy(schema_dict)
    keys = list(schema_dict_['tables'].keys())
    keys.sort(key=lambda x: -len(x))
    # tables
    for table in keys:
        if f'{table.lower()}' not in sql.lower():
            schema_dict_['tables'].pop(table, None)
    # columns
    keys = list(schema_dict_['tables'].keys())
    keys.sort(key=lambda x: -len(x))
    for table in keys:
        cls_keys = list(schema_dict_['tables'][table].keys())
        cls_keys.sort(key=lambda x: -len(x))
        table_dict = copy.deepcopy(schema_dict_['tables'][table])
        for cls in cls_keys:
            if cls.lower() not in sql.lower():
                schema_dict_['tables'][table].pop(cls, None)
        if len(schema_dict_['tables'][table].keys()) == 0:
            # fall back to the primary-key columns (e.g. for COUNT(*))
            for cls in table_dict.keys():
                if table_dict[cls][1] == True:
                    schema_dict_['tables'][table][cls] = table_dict[cls]
        if len(schema_dict_['tables'][table].keys()) == 0:
            cls_list = list(table_dict.keys())
            schema_dict_['tables'][table][cls_list[0]] = table_dict[cls_list[0]]
            schema_dict_['tables'][table][cls_list[1]] = table_dict[cls_list[1]]
    return schema_dict_


@func_set_timeout(5)
def execute_query_limit(db_path, query):
    error = ''
    result = None
    conn = sqlite3.connect(db_path, timeout=5.0, check_same_thread=False)
    cursor = conn.cursor()
    cursor.execute(query)
    result = cursor.fetchone()
    cursor.close()
    conn.close()
    return result, error


def execute_query(db_path, query):
    try:
        result, error = execute_query_limit(db_path, query)
    except func_timeout.exceptions.FunctionTimedOut:
        error = "SQL execution timeout"
        print("*" * 30, error, query)
        result = None
    except Exception as e:
        error = str(e)
        print("*" * 30, error, query)
        result = None
    return result, error
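
# Hedged usage sketch (db path is illustrative):
#
#   result, error = execute_query('./dataset/spider/database/concert_singer/concert_singer.sqlite',
#                                 'SELECT count(*) FROM singer;')
#   # error == '' means the query executed; the pipeline uses this signal to
#   # decide whether a predicted SQL needs correction or continuation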

def replace_syn(data1, data2):
    for i in range(len(data1)):
        if data1[i]['question'] == data2[i]['SpiderQuestion']:
            data1[i]['question'] = data2[i]['SpiderSynQuestion']
    return data1


def eval_all(args):
    dataset = args.dataset
    mode = args.mode
    data_tuples = parse_dataset(args.data_path, mode, dataset)
    batch_size = args.batch_size

    if dataset == 'spider_syn':
        data2 = read_json_file(os.path.join(args.data_path, 'spider', f'{mode}_syn.json'))
        data_tuples = replace_syn(data_tuples, data2)
        dataset = 'spider'
        args.tag += '_syn'
    if dataset == 'spider_DK':
        args.tag += '_DK'
        dataset = 'spider'
    if dataset == 'spider_real':
        args.tag += '_real'
        dataset = 'spider'

    if dataset == 'bird':
        kk = 5  # sample rows fetched per table
    else:
        kk = 10
    n_example_rows = 1 if dataset == 'bird' else 3  # sample rows shown in prompts

    if 'online' in args.tag:
        generator = LLM_Online()
    else:
        generator = LLM_Model(args.LLM_model)
    tag = args.tag

    # flags toggle: schema linking / pre-SQL / noisy correction / continuation writing
    old_flgs = args.flags
    flags = (args.flags.split('_') + ['0', '0', '0', '0'])[:4]  # pad: the default '1_0_0' has only 3 fields
    flg1, flg2, flg3, flg4 = (f == '1' for f in flags)

    # generate SQL
    sql_results = []
    data_header = [["NLQ", "Predict", "GOLD", 'database']]
    prompts = []
    for index, row in enumerate(data_tuples):
        if 'spider' in dataset:
            row['SQL'] = row['query']
        if 'drspider' in dataset:
            row['SQL'] = row['query']
        question, db_id = row['question'], row['db_id']
        if dataset == 'spider':
            if mode == 'test':
                db_path = os.path.join(args.data_path, dataset, 'test_database', db_id, f"{db_id}.sqlite")
            else:
                db_path = os.path.join(args.data_path, dataset, 'database', db_id, f"{db_id}.sqlite")
        elif dataset == 'drspider':
            db_path = os.path.join(args.data_path, db_id, f"{db_id}.sqlite")
        elif dataset == 'bird':
            db_path = os.path.join(args.data_path, dataset, f'{mode}/{mode}_databases', db_id, f"{db_id}.sqlite")
        else:
            raise TypeError(f"Unexpected dataset: {dataset}.")
        schema_dict = get_schema_dict(db_path, kk=kk)
        database_schema, examples = get_schema_str_and_examples(schema_dict)
        schema_dict_ = schema_dict
        if dataset == 'bird':
            prompt = [question, schema_dict,
                      f"\n\n/* Question hint */\n{row['evidence']}" if row['evidence'] != '' else '',
                      schema_dict_]
        else:
            prompt = [question, schema_dict, '', schema_dict_]
        prompts.append([database_schema, str(examples), question, row['SQL'], db_id, prompt, db_path])

    n_samples = len(data_tuples)
    n_batches = (n_samples - 1) // batch_size + 1
    for i in range(n_batches):
        start = i * batch_size
        end = n_samples if i == n_batches - 1 else (i + 1) * batch_size
        batch_prompts = prompts[start: end]
        schema_dicts = []  # only keep the tables
        # schema linking
        if flg1 or flg2:
            response_strs = None
            c_response_strs = None
            if flg1:
                if args.eval_sft == 1:
                    c_response_strs = generator.generate_response(
                        prompts=[prompt_sl_temp_sft.format(ds=j[0], sr=get_example_str(j[5][1], n_example_rows), qs=j[2], hint=j[5][2]) for j in batch_prompts])
                else:
                    c_response_strs = generator.generate_response(
                        prompts=[prompt_sl_temp.format(ds=j[0], sr=get_example_str(j[5][1], n_example_rows), qs=j[2], hint=j[5][2]) for j in batch_prompts])
            if flg2:
                response_strs = generator.generate_response(
                    prompts=[prompt_temp.format(ds=j[0], sr=get_example_str(j[5][1], n_example_rows), qs=j[2], hint=j[5][2]) for j in batch_prompts])
            if c_response_strs is None:
                c_response_strs = response_strs
            if response_strs is None:
                response_strs = c_response_strs
            for j, response_str in enumerate(c_response_strs):
                schema_dict = batch_prompts[j][5][1]
                gt_sql = batch_prompts[j][3]
                # schema_dict_gt = filter_dict_by_sql(batch_prompts[j][5][1], gt_sql)
                # sl: prune by the schema-linking output
                c_sql_str1 = response_str.replace('"', "'").replace('\'', "")
                schema_dict_1 = filter_dict_by_sl(batch_prompts[j][5][1], c_sql_str1)
                # pre-sql: prune by the preliminary SQL
                c_sql_str2 = parse_sql_from_string(response_strs[j]).replace('"', "'").replace('\'', "")
                schema_dict_2 = filter_dict_by_sql(batch_prompts[j][5][1], c_sql_str2)
                # union the two pruned schemas
                schema_dict_old = copy.deepcopy(schema_dict)
                keys1 = schema_dict_1['tables'].keys()
                keys2 = schema_dict_2['tables'].keys()
                all_keys = list(schema_dict_old['tables'].keys())
                for key in all_keys:
                    if key not in keys1 and key not in keys2:
                        schema_dict_old['tables'].pop(key, None)
                    else:
                        clss = []
                        if key in keys1:
                            clss += schema_dict_1['tables'][key].keys()
                        if key in keys2:
                            clss += schema_dict_2['tables'][key].keys()
                        clss = list(set(clss))
                        for k in list(schema_dict_old['tables'][key].keys()):
                            if k not in clss:
                                schema_dict_old['tables'][key].pop(k, None)
                        if len(schema_dict_old['tables'][key].keys()) == 0:
                            schema_dict_old['tables'].pop(key, None)
                schema_dict_ = schema_dict_old
                # schema_dict_ = schema_dict_gt  # gt
                schema_dict_table = copy.deepcopy(schema_dict)
                for key in schema_dict['tables'].keys():
                    if key not in schema_dict_['tables'].keys():
                        schema_dict_table['tables'].pop(key, None)
                schema_dicts.append(schema_dict_table)
                if j == 0:
                    print("######", response_str, list(schema_dict_old['tables'].keys()))
                ds, sr = get_schema_str_and_examples(schema_dict_)
                batch_prompts[j][0] = ds
                batch_prompts[j][1] = sr
        else:
            for j, v in enumerate(batch_prompts):
                batch_prompts[j][1] = get_example_str(batch_prompts[j][5][1], n_example_rows)

        # text-to-sql
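        # Build the final text-to-SQL prompts from the (possibly pruned)
        # schema strings and sample rows prepared above; one SQL candidate
        # is decoded per question in the batch.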
        final_prompts = [prompt_temp.format(ds=j[0], sr=j[1], qs=j[2], hint=j[5][2]) for j in batch_prompts]
        response_strs = generator.generate_response(prompts=final_prompts)

        def contains_subquery(sql_query, tables):
            # rough difficulty estimate based on how many schema tables the
            # SQL references via FROM/JOIN
            sql = sql_query.lower()
            select_num = 0
            join_num = 0
            tmp = sql
            while 'select' in tmp:
                tmp = tmp[tmp.find('select') + 6:]
                select_num += 1
            tmp = sql
            while 'join' in tmp:
                tmp = tmp[tmp.find('join') + 4:]
                join_num += 1
            table_num = len([key for key in tables if f"from {key.lower()}" in sql or f"join {key.lower()}" in sql])
            if table_num == 1:
                hard = 1
            elif table_num == 2:
                hard = 2
            else:
                hard = 3
            return hard

        nc_idx = []
        continue_sqls = []
        # noisy correction
        if flg3:
            predSQLs = [parse_sql_from_string(response_str) for response_str in response_strs]
            nc_prompts = []
            for j in range(len(response_strs)):
                v = batch_prompts[j]
                predSQL = predSQLs[j]
                ds = get_schema_str_and_examples(v[5][1])[0]
                sr = get_example_str(v[5][1], n_example_rows)
                ex_hint = execute_query(batch_prompts[j][6], predSQL)[1]
                if ex_hint != '':
                    ex_hint = f"\n\n/* Execution exception */\n{ex_hint}"
                # ex_hint = ''
                if args.eval_sft == 1:
                    nc_prompts.append(prompt_nc_temp_sft.format(ds=ds, sr=sr, qs=v[2], ex_hint=ex_hint, hint=v[5][2], sql=predSQL))
                else:
                    nc_prompts.append(prompt_nc_temp.format(ds=ds, sr=sr, qs=v[2], ex_hint=ex_hint, hint=v[5][2], sql=predSQL))
            response_strs_ = generator.generate_response(prompts=nc_prompts)
            for idx, v in enumerate(response_strs_):
                if idx == 0:
                    print("******", nc_prompts[0], '\n', v, batch_prompts[idx][3])
                v_lower = v.lower()
                v_lower = v_lower[:v_lower.find('select') + 6] if 'select' in v_lower else v_lower
                # the model proposed a replacement only if it emitted SQL and
                # did not declare the original query correct
                correct_phrases = ('can correctly answer', 'can answer correctly', 'is correct',
                                   'will answer correctly', 'will correctly answer', 'can answer ',
                                   'can accurately answer', 'can answer accurately',
                                   'will answer accurately', 'will accurately answer')
                flag = 'select' in v_lower and not any(p in v_lower for p in correct_phrases)
                pre_sql = parse_sql_from_string(response_strs[idx])
                if flag:
                    ex_flg2 = True if execute_query(batch_prompts[idx][6], parse_sql_from_string(v))[1] == '' else False
                    if ex_flg2:
                        response_strs[idx] = v
                        pre_sql = parse_sql_from_string(response_strs[idx])
                ex_flg3 = True if execute_query(batch_prompts[idx][6], pre_sql)[1] == '' else False
                hard = contains_subquery(pre_sql, batch_prompts[idx][5][1]['tables'].keys())
                if ex_flg3 == False or hard > 2:
                    common_sql = 'SELECT '
                    continue_sqls.append(common_sql)
                    nc_idx.append(idx)
        else:
            for idx, v in enumerate(response_strs):
                pre_sql = parse_sql_from_string(response_strs[idx])
                ex_flg3 = True if execute_query(batch_prompts[idx][6], pre_sql)[1] == '' else False
                hard = contains_subquery(pre_sql, batch_prompts[idx][5][1]['tables'].keys())
                if ex_flg3 == False or hard > 2:
                    common_sql = 'SELECT '
                    continue_sqls.append(common_sql)
                    nc_idx.append(idx)

        # continuation writing
        if flg4:
            cl_prompts = []
            for j, idx in enumerate(nc_idx):
                v = batch_prompts[idx]
                ds = get_schema_str_and_examples(v[5][1])[0]
                sr = get_example_str(v[5][1], n_example_rows)
                common_sql = continue_sqls[j]
                if args.eval_sft == 1:
                    cl_prompts.append(prompt_cw_temp_sft.format(ds=ds, sr=sr, qs=v[2], hint=v[5][2], sql=common_sql))
                else:
                    cl_prompts.append(prompt_cw_temp.format(ds=ds, sr=sr, qs=v[2], hint=v[5][2], sql=common_sql))
            if len(nc_idx) > 0:
                response_strs_ = generator.generate_response(prompts=cl_prompts)
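                # A continuation is adopted below only if it executes and
                # returns a row; otherwise the original prediction is kept.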
print("%%%%%%%%%%%%%%%%%%",response_strs_[0]) for idx, v in enumerate(nc_idx): if execute_query(batch_prompts[v][6], parse_sql_from_string(response_strs_[idx]))[0] is not None: response_strs[v] = response_strs_[idx] for j, response_str in enumerate(response_strs): database_schema = batch_prompts[j][0] question = batch_prompts[j][2] gt_sql = replace_multiple_spaces(batch_prompts[j][3]) db_id = batch_prompts[j][4] prompt = final_prompts[j] print(f"=={start+j+1}/{len(data_tuples)}=={db_id}=={tag}==================") try: if dataset == 'spider': if mode == 'test': db_path = os.path.join(args.data_path, dataset, 'test_database', db_id, f"{db_id}.sqlite") else: db_path = os.path.join(args.data_path, dataset, 'database', db_id, f"{db_id}.sqlite") elif dataset == 'bird': db_path = os.path.join(args.data_path, dataset, f'{mode}/{mode}_databases', db_id, f"{db_id}.sqlite") else: raise TypeError(f"Unexpect dataset: {dataset}.") SQL_str = parse_sql_from_string(response_str) except Exception as e: res = f'error: {str(e)}' print(res, response_str) sql_results.append([question, SQL_str, gt_sql, db_id]) print(prompt) print(f"Question: {question}") print(f"Raw Resp: {response_str}") print(f"Answer: {SQL_str}") print(f"Ground: {gt_sql}") if SQL_str== 'None': exit() if not os.path.isdir(os.path.join(args.output_path, f"{tag}_{dataset}")): os.makedirs(os.path.join(args.output_path, f"{tag}_{dataset}")) with open(os.path.join(args.output_path, f"{tag}_{dataset}", f"rs_{old_flgs}.csv"), mode='w', newline='', encoding='utf-8') as file: writer = csv.writer(file) writer.writerows(data_header + sql_results) import os import pynvml pynvml.nvmlInit() def usegpu(need_gpu_count=1): nouse=[] for index in range(pynvml.nvmlDeviceGetCount()): # 这里的0是GPU id handle = pynvml.nvmlDeviceGetHandleByIndex(index) meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle) used= meminfo.used/meminfo.total if used<0.3: nouse.append(index) if len(nouse)>=need_gpu_count: os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, nouse[:need_gpu_count])) # return nouse[:need_gpu_count] print(nouse[:need_gpu_count]) return need_gpu_count elif len(nouse)>0: os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, nouse)) return len(nouse) else: return 0 if __name__ == "__main__": parser = argparse.ArgumentParser(description='SQL') parser.add_argument("--dataset", default='spider', type=str) parser.add_argument("--data_path", default='./dataset', type=str) parser.add_argument("--output_path", default='./dataset', type=str) parser.add_argument("--mode", default='dev', type=str) parser.add_argument("--tag", default='0701', type=str) parser.add_argument("--gpus", default=0, type=int) parser.add_argument("--eval_sft", default=1, type=int) parser.add_argument("--flags", default='1_0_0', type=str) parser.add_argument("--LLM_model", default='/disk2/qinyang/qwen2-1.5b-instruct', type=str) parser.add_argument("--batch_size", default=32, type=int) args = parser.parse_args() usegpu(need_gpu_count=args.gpus) print(args) eval_all(args)