# route.py
def eval_all(args):
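"""Generate SQL for every example in the chosen dataset and dump predictions to CSV.

The pipeline runs up to four optional stages, enabled by the bit flags in
args.flags (e.g. '1_1_1_1'): schema linking, preliminary-SQL linking,
noisy correction, and continuation writing.
"""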
dataset = args.dataset
mode = args.mode
data_tuples = parse_dataset(args.data_path, mode, dataset)
batch_size = args.batch_size
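# Spider variants share the base Spider databases: record the variant in the tag and
# fall back to the 'spider' layout (spider_syn also swaps in the synonym-rewritten questions)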
if dataset == 'spider_syn':
data2 = read_json_file(os.path.join(args.data_path, 'spider', f'{mode}_syn.json'))
data_tuples = replace_syn(data_tuples,data2)
dataset = 'spider'
args.tag += '_syn'
if dataset == 'spider_DK':
args.tag += '_DK'
dataset = 'spider'
if dataset == 'spider_real':
args.tag += '_real'
dataset = 'spider'
# dataset-dependent sizes: kk is passed to get_schema_dict, kkkkk to get_example_str (BIRD uses smaller values)
kk = 5 if dataset == 'bird' else 10
kkkkk = 1 if dataset == 'bird' else 3
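# choose the generation backend: a hosted/online LLM when the tag requests it, otherwise a local model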
if 'online' in args.tag:
generator = LLM_Online()
else:
generator = LLM_Model(args.LLM_model)
tag = args.tag
old_flgs = args.flags
args.flags = args.flags.split('_')
# each '1' enables one stage: flg1 schema linking, flg2 preliminary-SQL linking,
# flg3 noisy correction, flg4 continuation writing
flg1, flg2, flg3, flg4 = (f == '1' for f in args.flags[:4])
# generate SQL
if True:
sql_results = []
data_header = [["NLQ", "Predict", "GOLD", 'database']]
prompts = []
for index, row in enumerate(data_tuples):
# Spider-style datasets (including drspider) store the gold SQL under 'query'
if 'spider' in dataset:
row['SQL'] = row['query']
question, db_id = row['question'], row['db_id']
if dataset == 'spider':
if mode == 'test':
db_path = os.path.join(args.data_path, dataset, 'test_database', db_id, f"{db_id}.sqlite")
else:
db_path = os.path.join(args.data_path, dataset, 'database', db_id, f"{db_id}.sqlite")
elif dataset == 'drspider':
db_path = os.path.join(args.data_path, db_id, f"{db_id}.sqlite")
elif dataset == 'bird':
db_path = os.path.join(args.data_path, dataset, f'{mode}/{mode}_databases', db_id, f"{db_id}.sqlite")
else:
raise TypeError(f"Unexpect dataset: {dataset}.")
schema_dict = get_schema_dict(db_path, kk = kk)
database_schema, examples = get_schmea_str_and_examples(schema_dict)
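# schema_dict_ starts as the full schema; schema linking below may replace it with a pruned copy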
schema_dict_ = schema_dict
if dataset == 'bird':
prompt = [question, schema_dict, f"\n\n/* Question hint */\n{row['evidence']}" if row['evidence'] != '' else '', schema_dict_]
else:
prompt = [question, schema_dict, '', schema_dict_]
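# layout per entry: [schema_str, example_rows, question, gold_sql, db_id, [question, schema_dict, hint, schema_dict_], db_path]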
prompts.append([database_schema, str(examples), question, row['SQL'], db_id, prompt, db_path])
n_samples = len(data_tuples)
n_batches = (n_samples - 1)//batch_size + 1
for i in range(n_batches):
start = i*batch_size
end = n_samples if i == n_batches - 1 else (i + 1) * batch_size
batch_prompts = prompts[start: end]
schema_dicts = [] # only keep the tables
# schema linking
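# flg1 asks the model to link schema elements directly; flg2 drafts a preliminary SQL and
# mines the tables/columns it references; if only one is enabled, its output stands in for the other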
if flg1 or flg2:
response_strs = None
c_response_strs = None
if flg1:
if args.eval_sft == 1:
c_response_strs = generator.generate_response(prompts=[prompt_sl_temp_sft.format(ds=j[0], sr=get_example_str(j[5][1], kkkkk), qs=j[2], hint=j[5][2]) for j in batch_prompts])
else:
c_response_strs = generator.generate_response(prompts=[prompt_sl_temp.format(ds=j[0], sr=get_example_str(j[5][1], kkkkk), qs=j[2], hint=j[5][2]) for j in batch_prompts])
if flg2:
response_strs = generator.generate_response(prompts=[prompt_temp.format(ds=j[0],sr=get_example_str(j[5][1],kkkkk),qs=j[2],hint=j[5][2]) for j in batch_prompts])
if c_response_strs is None:
c_response_strs = response_strs
if response_strs is None:
response_strs = c_response_strs
for j, response_str in enumerate(c_response_strs):
schema_dict = batch_prompts[j][5][1]
gt_sql = batch_prompts[j][3]
# schema_dict_gt = filter_dict_by_sql(batch_prompts[j][5][1], gt_sql)
# schema-linking response: strip quotes, then filter the schema by the linked elements
c_sql_str1 = response_str.replace('"', '').replace("'", '')
schema_dict_1 = filter_dict_by_sl(batch_prompts[j][5][1], c_sql_str1)
# preliminary SQL: strip quotes, then filter the schema by the referenced tables/columns
c_sql_str2 = parse_sql_from_string(response_strs[j]).replace('"', '').replace("'", '')
schema_dict_2 = filter_dict_by_sql(batch_prompts[j][5][1], c_sql_str2)
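# take the union of both linking results: keep a table/column if either pass selected it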
schema_dict_old = copy.deepcopy(schema_dict)
keys1 = schema_dict_1['tables'].keys()
keys2 = schema_dict_2['tables'].keys()
all_keys = list(schema_dict_old['tables'].keys())
for key in all_keys:
if key not in keys1 and key not in keys2:
schema_dict_old['tables'].pop(key, None)
else:
clss = []
if key in keys1:
clss += schema_dict_1['tables'][key].keys()
if key in keys2:
clss += schema_dict_2['tables'][key].keys()
clss = list(set(clss))
for k in list(schema_dict_old['tables'][key].keys()):
if k not in clss:
schema_dict_old['tables'][key].pop(k,None)
if len(schema_dict_old['tables'][key].keys()) == 0:
schema_dict_old['tables'].pop(key, None)
schema_dict_ = schema_dict_old
# schema_dict_ = schema_dict_gt # gt
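# schema_dict_table keeps only the linked tables, but with all of their original columns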
schema_dict_table = copy.deepcopy(schema_dict)
for key in schema_dict['tables'].keys():
if key not in schema_dict_['tables'].keys():
schema_dict_table['tables'].pop(key,None)
schema_dicts.append(schema_dict_table)
if j == 0:
print("######", response_str, list(schema_dict_old['tables'].keys()) )
ds, sr = get_schmea_str_and_examples(schema_dict_)
batch_prompts[j][0] = ds
batch_prompts[j][1] = sr
else:
for j, v in enumerate(batch_prompts):
batch_prompts[j][1] = get_example_str(batch_prompts[j][5][1],kkkkk)
# text-to-sql
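# generate the final SQL from the (possibly pruned) schema string and example rows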
final_prompts = [prompt_temp.format(ds=j[0], sr=j[1], qs=j[2], hint=j[5][2]) for j in batch_prompts]
response_strs = generator.generate_response(prompts=final_prompts)
def contains_subquery(sql_query, tables):
# rough hardness estimate based on how many schema tables the SQL touches;
# select_num and join_num are counted but do not affect the returned score
sql = sql_query.lower()
select_num = 0
join_num = 0
tmp = sql
while 'select' in tmp:
tmp = tmp[tmp.find('select') + 6:]
select_num += 1
tmp = sql
while 'join' in tmp:
tmp = tmp[tmp.find('join') + 4:]
join_num += 1
table_num = len([key for key in tables if f"from {key.lower()}" in sql or f"join {key.lower()}" in sql])
if table_num == 1:
hard = 1
elif table_num == 2:
hard = 2
else:
hard = 3
return hard
nc_idx = []
continue_sqls = []
# noisy correction
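# show the model its own SQL together with any execution exception and let it confirm or rewrite the query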
if flg3:
predSQLs = [parse_sql_from_string(response_str) for response_str in response_strs]
nc_prompts = []
for j in range(len(response_strs)):
v = batch_prompts[j]
predSQL = predSQLs[j]
ds = get_schmea_str_and_examples(v[5][1])[0]
sr = get_example_str(v[5][1],kkkkk)
ex_hint = execute_query(batch_prompts[j][6], predSQL)[1]
if ex_hint != '':
ex_hint = f"\n\n/* Execution exception */\n{ex_hint}"
# ex_hint = ''
if args.eval_sft == 1:
nc_prompts.append(prompt_nc_temp_sft.format(ds=ds ,sr=sr, qs=v[2], ex_hint = ex_hint, hint=v[5][2],sql = predSQL))
else:
nc_prompts.append(prompt_nc_temp.format(ds=ds ,sr=sr, qs=v[2], ex_hint = ex_hint, hint=v[5][2],sql = predSQL))
response_strs_ = generator.generate_response(prompts=nc_prompts)
for idx, v in enumerate(response_strs_):
if idx == 0:
print("******", nc_prompts[0], '\n', v, batch_prompts[idx][3])
v_lower = v.lower()
v_lower = v_lower[:v_lower.find('select') + 6] if 'select' in v_lower else v_lower
# treat the response as a correction only if it contains a SELECT and does not
# claim the original SQL was already correct
flag = 'select' in v_lower and ('can correctly answer' not in v_lower
and 'can answer correctly' not in v_lower
and 'is correct' not in v_lower
and 'will answer correctly' not in v_lower
and 'will correctly answer' not in v_lower
and 'can answer ' not in v_lower
and 'can accurately answer' not in v_lower
and 'can answer accurately' not in v_lower
and 'will answer accurately' not in v_lower
and 'will accurately answer' not in v_lower)
pre_sql = parse_sql_from_string(response_strs[idx])
if flag:
# accept the corrected SQL only if it executes without error
ex_flg2 = execute_query(batch_prompts[idx][6], parse_sql_from_string(v))[1] == ''
if ex_flg2:
response_strs[idx] = v
pre_sql = parse_sql_from_string(response_strs[idx])
ex_flg3 = execute_query(batch_prompts[idx][6], pre_sql)[1] == ''
hard = contains_subquery(pre_sql, batch_prompts[idx][5][1]['tables'].keys())
# queue failing or multi-table queries for continuation writing
if not ex_flg3 or hard > 2:
common_sql = 'SELECT '
continue_sqls.append(common_sql)
nc_idx.append(idx)
else:
for idx, v in enumerate(response_strs):
pre_sql = parse_sql_from_string(response_strs[idx])
ex_flg3 = execute_query(batch_prompts[idx][6], pre_sql)[1] == ''
hard = contains_subquery(pre_sql, batch_prompts[idx][5][1]['tables'].keys())
if not ex_flg3 or hard > 2:
common_sql = 'SELECT '
continue_sqls.append(common_sql)
nc_idx.append(idx)
# continuation writing
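# for the queries flagged above, restart from a bare 'SELECT ' prefix and let the model
# complete the query; keep the completion only if it executes successfully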
if flg4:
cl_prompts = []
for j, idx in enumerate(nc_idx):
v = batch_prompts[idx]
ds = get_schmea_str_and_examples(v[5][1])[0]
sr = get_example_str(v[5][1],kkkkk)
common_sql = continue_sqls[j]
if args.eval_sft == 1:
cl_prompts.append(prompt_cw_temp_sft.format(ds=ds, sr=sr, qs=v[2],hint=v[5][2], sql = common_sql))
else:
cl_prompts.append(prompt_cw_temp.format(ds=ds, sr=sr, qs=v[2],hint=v[5][2], sql = common_sql))
if len(nc_idx) > 0:
response_strs_ = generator.generate_response(prompts=cl_prompts)
print("%%%%%%%%%%%%%%%%%%",response_strs_[0])
for idx, v in enumerate(nc_idx):
if execute_query(batch_prompts[v][6], parse_sql_from_string(response_strs_[idx]))[0] is not None:
response_strs[v] = response_strs_[idx]
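# parse, log, and collect the final SQL for every example in the batch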
for j, response_str in enumerate(response_strs):
database_schema = batch_prompts[j][0]
question = batch_prompts[j][2]
gt_sql = replace_multiple_spaces(batch_prompts[j][3])
db_id = batch_prompts[j][4]
prompt = final_prompts[j]
print(f"=={start+j+1}/{len(data_tuples)}=={db_id}=={tag}==================")
try:
if dataset == 'spider':
if mode == 'test':
db_path = os.path.join(args.data_path, dataset, 'test_database', db_id, f"{db_id}.sqlite")
else:
db_path = os.path.join(args.data_path, dataset, 'database', db_id, f"{db_id}.sqlite")
elif dataset == 'bird':
db_path = os.path.join(args.data_path, dataset, f'{mode}/{mode}_databases', db_id, f"{db_id}.sqlite")
else:
raise TypeError(f"Unexpect dataset: {dataset}.")
SQL_str = parse_sql_from_string(response_str)
except Exception as e:
SQL_str = 'None'
res = f'error: {str(e)}'
print(res, response_str)
sql_results.append([question, SQL_str, gt_sql, db_id])
print(prompt)
print(f"Question: {question}")
print(f"Raw Resp: {response_str}")
print(f"Answer: {SQL_str}")
print(f"Ground: {gt_sql}")
# stop the run if the response could not be parsed into SQL
if SQL_str == 'None':
exit()
# write [NLQ, predicted SQL, gold SQL, database] rows for downstream evaluation
os.makedirs(os.path.join(args.output_path, f"{tag}_{dataset}"), exist_ok=True)
with open(os.path.join(args.output_path, f"{tag}_{dataset}", f"rs_{old_flgs}.csv"), mode='w', newline='', encoding='utf-8') as file:
writer = csv.writer(file)
writer.writerows(data_header + sql_results)