def eval_all()

in route.py [0:0]


# Phrases that, when they appear in a noisy-correction reply before its first
# "select", mean the model is endorsing the SQL it was shown rather than
# proposing a replacement.
_NC_CONFIRM_PHRASES = (
    'can correctly answer',
    'can answer correctly',
    'is correct',
    'will answer correctly',
    'will correctly answer',
    'can answer ',
    'can accurately answer',
    'can answer accurately',
    'will answer accurately',
    'will accurately answer',
)


def _resolve_db_path(data_path, dataset, mode, db_id):
    """Return the SQLite file path of ``db_id`` for a supported dataset.

    Raises:
        TypeError: if ``dataset`` is not 'spider', 'drspider' or 'bird'.
    """
    if dataset == 'spider':
        db_dir = 'test_database' if mode == 'test' else 'database'
        return os.path.join(data_path, dataset, db_dir, db_id, f"{db_id}.sqlite")
    if dataset == 'drspider':
        return os.path.join(data_path, db_id, f"{db_id}.sqlite")
    if dataset == 'bird':
        return os.path.join(data_path, dataset, f'{mode}/{mode}_databases', db_id, f"{db_id}.sqlite")
    raise TypeError(f"Unexpect dataset: {dataset}.")


def _sql_hardness(sql_query, tables):
    """Bucket ``sql_query`` by how many of ``tables`` it references.

    A table counts when ``from <table>`` or ``join <table>`` occurs in the
    lower-cased SQL text (plain substring match). Returns 1 for exactly one
    table, 2 for exactly two, and 3 otherwise (including zero matches).
    """
    sql = sql_query.lower()
    table_num = sum(
        1 for key in tables
        if f"from {key.lower()}" in sql or f"join {key.lower()}" in sql
    )
    if table_num == 1:
        return 1
    if table_num == 2:
        return 2
    return 3


def _nc_proposes_new_sql(reply):
    """Return True if a noisy-correction reply contains 'select' and, up to and
    including its first 'select', none of the confirmation phrases."""
    head = reply.lower()
    if 'select' not in head:
        return False
    head = head[:head.find('select') + 6]
    return not any(phrase in head for phrase in _NC_CONFIRM_PHRASES)


def _link_schema(generator, batch_prompts, use_sl, use_presql, eval_sft, example_k):
    """Prune each prompt's schema to the tables/columns chosen by schema linking.

    ``use_sl`` asks the model for a schema-linking answer (``prompt_sl_temp``
    or its ``_sft`` variant); ``use_presql`` asks for a preliminary SQL
    (``prompt_temp``). When only one is enabled its responses stand in for the
    other. A table/column is kept if either source selects it. Overwrites
    ``batch_prompts[j][0]`` / ``[1]`` with the pruned schema string and examples.
    """
    def _fmt(template):
        return [template.format(ds=p[0], sr=get_example_str(p[5][1], example_k), qs=p[2], hint=p[5][2])
                for p in batch_prompts]

    sl_resps = None
    sql_resps = None
    if use_sl:
        sl_temp = prompt_sl_temp_sft if eval_sft == 1 else prompt_sl_temp
        sl_resps = generator.generate_response(prompts=_fmt(sl_temp))
    if use_presql:
        sql_resps = generator.generate_response(prompts=_fmt(prompt_temp))
    if sl_resps is None:
        sl_resps = sql_resps
    if sql_resps is None:
        sql_resps = sl_resps

    for j, sl_resp in enumerate(sl_resps):
        full_schema = batch_prompts[j][5][1]

        # Both texts are stripped of all quote characters before filtering.
        sl_text = sl_resp.replace('"', "'").replace("'", "")
        linked_tables = filter_dict_by_sl(full_schema, sl_text)['tables']
        sql_text = parse_sql_from_string(sql_resps[j]).replace('"', "'").replace("'", "")
        used_tables = filter_dict_by_sql(full_schema, sql_text)['tables']

        pruned = copy.deepcopy(full_schema)
        pruned_tables = pruned['tables']
        for table in list(pruned_tables.keys()):
            keep_cols = set()
            if table in linked_tables:
                keep_cols.update(linked_tables[table].keys())
            if table in used_tables:
                keep_cols.update(used_tables[table].keys())
            # A table selected by neither source ends up with no columns and is dropped.
            for col in list(pruned_tables[table].keys()):
                if col not in keep_cols:
                    pruned_tables[table].pop(col, None)
            if len(pruned_tables[table].keys()) == 0:
                pruned_tables.pop(table, None)

        if j == 0:
            print("######", sl_resp, list(pruned['tables'].keys()) )

        ds, sr = get_schmea_str_and_examples(pruned)
        batch_prompts[j][0] = ds
        batch_prompts[j][1] = sr


def _noisy_correct(generator, batch_prompts, response_strs, eval_sft, example_k):
    """Ask the model to review each predicted SQL, and adopt its rewrite in place
    when it proposes new SQL whose execution reports no error."""
    nc_temp = prompt_nc_temp_sft if eval_sft == 1 else prompt_nc_temp
    nc_prompts = []
    for j, resp in enumerate(response_strs):
        p = batch_prompts[j]
        pred_sql = parse_sql_from_string(resp)
        ds = get_schmea_str_and_examples(p[5][1])[0]
        sr = get_example_str(p[5][1], example_k)
        ex_hint = execute_query(p[6], pred_sql)[1]
        if ex_hint != '':
            ex_hint = f"\n\n/* Execution exception */\n{ex_hint}"
        nc_prompts.append(nc_temp.format(ds=ds, sr=sr, qs=p[2], ex_hint=ex_hint, hint=p[5][2], sql=pred_sql))

    replies = generator.generate_response(prompts=nc_prompts)
    for idx, reply in enumerate(replies):
        if idx == 0:
            print("******", nc_prompts[0], '\n', reply, batch_prompts[idx][3])
        if _nc_proposes_new_sql(reply) and execute_query(batch_prompts[idx][6], parse_sql_from_string(reply))[1] == '':
            response_strs[idx] = reply


def _needs_continuation(batch_prompts, response_strs):
    """Return indices whose SQL reports an execution error or falls in hardness bucket 3."""
    nc_idx = []
    for idx, resp in enumerate(response_strs):
        pre_sql = parse_sql_from_string(resp)
        runs_ok = execute_query(batch_prompts[idx][6], pre_sql)[1] == ''
        hard = _sql_hardness(pre_sql, batch_prompts[idx][5][1]['tables'].keys())
        if not runs_ok or hard > 2:
            nc_idx.append(idx)
    return nc_idx


def _continue_write(generator, batch_prompts, response_strs, nc_idx, eval_sft, example_k):
    """Regenerate the SQL at ``nc_idx`` from a 'SELECT ' prefix and keep each
    regeneration whose ``execute_query(...)[0]`` is not None (in place)."""
    if not nc_idx:
        return
    cw_temp = prompt_cw_temp_sft if eval_sft == 1 else prompt_cw_temp
    common_sql = 'SELECT '
    cl_prompts = []
    for idx in nc_idx:
        p = batch_prompts[idx]
        ds = get_schmea_str_and_examples(p[5][1])[0]
        sr = get_example_str(p[5][1], example_k)
        cl_prompts.append(cw_temp.format(ds=ds, sr=sr, qs=p[2], hint=p[5][2], sql=common_sql))

    replies = generator.generate_response(prompts=cl_prompts)
    print("%%%%%%%%%%%%%%%%%%", replies[0])
    for idx, reply in zip(nc_idx, replies):
        # NOTE(review): success here is "[0] is not None", while the other stages
        # use "[1] == ''"; kept as-is -- confirm against execute_query.
        if execute_query(batch_prompts[idx][6], parse_sql_from_string(reply))[0] is not None:
            response_strs[idx] = reply


def eval_all(args):
    """Generate SQL for every example of ``args.dataset`` and write predictions to CSV.

    Per batch of ``args.batch_size`` examples, the stages are toggled by the
    '_'-separated fields of ``args.flags`` (a field equal to '1' enables it):
      1. field 0 / field 1 -- schema linking via a schema-linking prompt and/or
         a preliminary SQL; the prompt schema is pruned to the union.
      2. always -- text-to-SQL with ``prompt_temp``.
      3. field 2 -- noisy correction: the model reviews each SQL.
      4. field 3 -- continuation writing for SQL that reports an execution
         error or references zero or three-plus tables.

    Side effects: appends '_syn' / '_DK' / '_real' to ``args.tag`` for the
    Spider variants, replaces ``args.flags`` with its split list, and rewrites
    ``<output_path>/<tag>_<dataset>/rs_<flags>.csv`` after every example.

    Raises:
        TypeError: for an unsupported dataset.
        ValueError: if ``args.flags`` has fewer than four fields.
        SystemExit: if a parsed prediction is the string 'None'.
    """
    dataset = args.dataset
    mode = args.mode
    data_tuples = parse_dataset(args.data_path, mode, dataset)
    batch_size = args.batch_size

    # Spider variants reuse Spider's databases and paths; only the tag changes
    # (spider_syn also swaps in the synonym-substituted questions).
    if dataset == 'spider_syn':
        data2 = read_json_file(os.path.join(args.data_path, 'spider', f'{mode}_syn.json'))
        data_tuples = replace_syn(data_tuples, data2)
        dataset = 'spider'
        args.tag += '_syn'
    elif dataset == 'spider_DK':
        args.tag += '_DK'
        dataset = 'spider'
    elif dataset == 'spider_real':
        args.tag += '_real'
        dataset = 'spider'

    # Dataset-specific knobs forwarded to get_schema_dict(kk=...) and
    # get_example_str(...); see those helpers for their exact meaning.
    kk = 5 if dataset == 'bird' else 10
    example_k = 1 if dataset == 'bird' else 3

    generator = LLM_Online() if 'online' in args.tag else LLM_Model(args.LLM_model)
    tag = args.tag

    old_flgs = args.flags
    args.flags = args.flags.split('_')
    if len(args.flags) < 4:
        raise ValueError(f"args.flags needs at least 4 '_'-separated fields, got {old_flgs!r}")
    flg1, flg2, flg3, flg4 = (field == '1' for field in args.flags[:4])

    # Each prompts[i] is a mutable list (stages overwrite [0] and [1]):
    #   [0] schema string, [1] examples, [2] question, [3] gold SQL, [4] db_id,
    #   [5] [question, full schema dict, hint, schema dict], [6] database path
    prompts = []
    for row in data_tuples:
        if 'spider' in dataset:  # true for both 'spider' and 'drspider'
            row['SQL'] = row['query']
        question, db_id = row['question'], row['db_id']
        db_path = _resolve_db_path(args.data_path, dataset, mode, db_id)

        schema_dict = get_schema_dict(db_path, kk=kk)
        database_schema, examples = get_schmea_str_and_examples(schema_dict)
        hint = ''
        if dataset == 'bird' and row['evidence'] != '':
            hint = f"\n\n/* Question hint */\n{row['evidence']}"
        prompt = [question, schema_dict, hint, schema_dict]
        prompts.append([database_schema, str(examples), question, row['SQL'], db_id, prompt, db_path])

    sql_results = []
    data_header = [["NLQ", "Predict", "GOLD", 'database']]
    n_samples = len(data_tuples)
    out_dir = os.path.join(args.output_path, f"{tag}_{dataset}")
    csv_path = os.path.join(out_dir, f"rs_{old_flgs}.csv")
    os.makedirs(out_dir, exist_ok=True)

    for start in range(0, n_samples, batch_size):
        batch_prompts = prompts[start:start + batch_size]

        # Stage 1: schema linking (or just render examples for the full schema).
        if flg1 or flg2:
            _link_schema(generator, batch_prompts, flg1, flg2, args.eval_sft, example_k)
        else:
            for p in batch_prompts:
                p[1] = get_example_str(p[5][1], example_k)

        # Stage 2: text-to-SQL.
        final_prompts = [prompt_temp.format(ds=p[0], sr=p[1], qs=p[2], hint=p[5][2]) for p in batch_prompts]
        response_strs = generator.generate_response(prompts=final_prompts)

        # Stage 3: noisy correction.
        if flg3:
            _noisy_correct(generator, batch_prompts, response_strs, args.eval_sft, example_k)

        # Stage 4: continuation writing. The candidate scan executes SQL, so it
        # only runs when its result is used.
        if flg4:
            nc_idx = _needs_continuation(batch_prompts, response_strs)
            _continue_write(generator, batch_prompts, response_strs, nc_idx, args.eval_sft, example_k)

        for j, response_str in enumerate(response_strs):
            question = batch_prompts[j][2]
            gt_sql = replace_multiple_spaces(batch_prompts[j][3])
            db_id = batch_prompts[j][4]
            prompt = final_prompts[j]
            print(f"=={start+j+1}/{len(data_tuples)}=={db_id}=={tag}==================")
            try:
                SQL_str = parse_sql_from_string(response_str)
            except Exception as e:
                # Record the failure for this row instead of crashing on an
                # unbound name or reusing the previous row's SQL.
                res = f'error: {str(e)}'
                print(res, response_str)
                SQL_str = res

            sql_results.append([question, SQL_str, gt_sql, db_id])

            print(prompt)
            print(f"Question: {question}")
            print(f"Raw Resp: {response_str}")
            print(f"Answer: {SQL_str}")
            print(f"Ground: {gt_sql}")

            # Rewritten after every example, so the file holds all rows so far.
            with open(csv_path, mode='w', newline='', encoding='utf-8') as file:
                csv.writer(file).writerows(data_header + sql_results)

            if SQL_str == 'None':
                # Abort the run (status 0, like the builtin exit()); this row is already saved.
                raise SystemExit()