in E2E_TOD/dataclass.py [0:0]
def build_all_evaluation_batch_list(self, ref_bs, ref_act, ref_db, input_contain_db, eva_batch_size, eva_mode):
    '''
    Build the evaluation batches for the dev or test split.

    Every instance of the selected split is parsed with
    ``self.parse_one_eva_instance`` into four parallel views: belief-state
    input ids, dialogue-action input ids, NLG input ids, and a parse dict.
    Each view is chunked with ``self.build_batch_list``, and the i-th chunks
    of the four views are grouped into one evaluation batch.

    ref_bs: whether using reference belief state to perform generation
        if with reference belief state, then we also use reference db result
        else generating belief state to query the db
    ref_act: whether using reference dialogue action to perform generation
        if true: it always means that we also use reference belief state
        if false: we can either use generated belief state and queried db result or
        use reference belief state and reference db result
    ref_db, input_contain_db: forwarded unchanged to ``self.parse_one_eva_instance``
    eva_batch_size: size of each evaluated batch (forwarded to ``self.build_batch_list``)
    eva_mode: 'dev' or 'test'; perform evaluation either on dev set or test set

    Returns a list whose entries are
    ``[bs_batch, da_batch, nlg_batch, parse_dict_batch]``. Entries whose
    belief-state batch is empty are skipped.

    Raises ValueError (a subclass of Exception, so existing
    ``except Exception`` handlers still catch it) for an unknown eva_mode.
    '''
    if eva_mode == 'dev':
        data_list = self.dev_data_list
    elif eva_mode == 'test':
        data_list = self.test_data_list
    else:
        raise ValueError(
            'Wrong Evaluation Mode!!! Expected \'dev\' or \'test\', got {!r}'.format(eva_mode))

    all_bs_input_id_list, all_da_input_id_list, all_nlg_input_id_list, all_parse_dict_list = \
        [], [], [], []
    for item in data_list:
        one_bs_input_id_list, one_da_input_id_list, one_nlg_input_id_list, one_parse_dict = \
            self.parse_one_eva_instance(item, ref_bs, ref_act, ref_db, input_contain_db)
        all_bs_input_id_list.append(one_bs_input_id_list)
        all_da_input_id_list.append(one_da_input_id_list)
        all_nlg_input_id_list.append(one_nlg_input_id_list)
        all_parse_dict_list.append(one_parse_dict)
    # The four lists are filled in lockstep above, so they always share one length.

    bs_batch_list = self.build_batch_list(all_bs_input_id_list, eva_batch_size)
    da_batch_list = self.build_batch_list(all_da_input_id_list, eva_batch_size)
    nlg_batch_list = self.build_batch_list(all_nlg_input_id_list, eva_batch_size)
    parse_dict_batch_list = self.build_batch_list(all_parse_dict_list, eva_batch_size)
    # Equal-length inputs chunked with the same batch size are expected to give
    # the same number of batches (assumes build_batch_list chunks purely by
    # length). Check it so zip() below cannot silently drop trailing batches.
    assert len(bs_batch_list) == len(da_batch_list) == len(nlg_batch_list) == len(parse_dict_batch_list)

    final_batch_list = []
    for bs_batch, da_batch, nlg_batch, parse_dict_batch in zip(
            bs_batch_list, da_batch_list, nlg_batch_list, parse_dict_batch_list):
        # A batch with no belief-state inputs has nothing to evaluate.
        if len(bs_batch) == 0:
            continue
        final_batch_list.append([bs_batch, da_batch, nlg_batch, parse_dict_batch])
    return final_batch_list