in E2E_TOD/eval.py [0:0]
def _get_metric_results(self, data, domain='all', file_list=None):
    """Run every evaluator on ``data``, log a summary and return the metrics.

    Args:
        data: Evaluation data, passed through unchanged to each ``self.*_eval``
            / ``self.bleu_metric`` helper.
        domain: Label stored under ``'domain'`` in the result and used in the
            log header; ``'all'`` selects the "All DOMAINS" header.
        file_list: Optional value forwarded unchanged to every helper.

    Returns:
        A dict of metrics when ``context_to_response_eval`` reports a non-zero
        dialogue count, otherwise ``None``. Dialogue-state keys
        (``joint_goal`` ... / ``joint_goal_delex`` ...) are only present when
        the matching branch of the configuration is active.
    """
    metric_result = {'domain': domain}
    bleu = self.bleu_metric(data, file_list)

    # Defaults keep the per-slot loop below a no-op (instead of a NameError)
    # when the configuration enables neither dialogue-state branch.
    slot_cnt, slot_corr = {}, {}

    eval_bspn = self.cfg.bspn_mode == 'bspn' or self.cfg.enable_dst
    eval_bsdx = self.cfg.bspn_mode == 'bsdx'

    if eval_bspn:
        jg, slot_f1, slot_acc, slot_cnt, slot_corr = self.dialog_state_tracking_eval(data, file_list)
        jg_nn, sf1_nn, sac_nn, _, _ = self.dialog_state_tracking_eval(
            data, file_list, no_name=True, no_book=False)
        jg_nb, sf1_nb, sac_nb, _, _ = self.dialog_state_tracking_eval(
            data, file_list, no_name=False, no_book=True)
        jg_nnnb, sf1_nnnb, sac_nnnb, _, _ = self.dialog_state_tracking_eval(
            data, file_list, no_name=True, no_book=True)
        metric_result.update({'joint_goal': jg, 'slot_acc': slot_acc, 'slot_f1': slot_f1})
    if eval_bsdx:
        # When both branches run, the per-slot counts from this call win.
        jg_, slot_f1_, slot_acc_, slot_cnt, slot_corr = self.dialog_state_tracking_eval(
            data, file_list, bspn_mode='bsdx')
        jg_nn_, sf1_nn_, sac_nn_, _, _ = self.dialog_state_tracking_eval(
            data, file_list, bspn_mode='bsdx', no_name=True, no_book=False)
        metric_result.update({'joint_goal_delex': jg_, 'slot_acc_delex': slot_acc_,
                              'slot_f1_delex': slot_f1_})

    # Per-slot accuracy in percent, ordered from worst to best.
    info_slots_acc = {
        slot: slot_corr.get(slot, 0) / total * 100
        for slot, total in slot_cnt.items()
    }
    info_slots_acc = OrderedDict(sorted(info_slots_acc.items(), key=lambda x: x[1]))

    act_f1 = self.aspn_eval(data, file_list)
    avg_act_num, avg_diverse_score = self.multi_act_eval(data, file_list)
    accu_single_dom, accu_multi_dom, multi_dom_num = self.domain_eval(data, file_list)

    # Read the flag from self.cfg like every other setting in this method,
    # rather than from a module-level ``cfg`` that may not exist.
    success, match, req_offer_counts, dial_num = self.context_to_response_eval(
        data, file_list, same_eval_as_cambridge=self.cfg.same_eval_as_cambridge)

    # Per-requestable offer rate in percent; the epsilon avoids dividing by
    # zero when a requestable has a zero total.
    req_slots_acc = {}
    for req in self.requestables:
        acc = req_offer_counts[req+'_offer']/(req_offer_counts[req+'_total'] + 1e-10)
        req_slots_acc[req] = acc * 100
    req_slots_acc = OrderedDict(sorted(req_slots_acc.items(), key=lambda x: x[1]))

    if not dial_num:
        return None

    metric_result.update({'act_f1':act_f1,'success':success, 'match':match, 'bleu': bleu,
                          'req_slots_acc':req_slots_acc, 'info_slots_acc': info_slots_acc,'dial_num': dial_num,
                          'accu_single_dom': accu_single_dom, 'accu_multi_dom': accu_multi_dom,
                          'avg_act_num': avg_act_num, 'avg_diverse_score': avg_diverse_score})

    if domain == 'all':
        logging.info('-------------------------- All DOMAINS --------------------------')
    else:
        logging.info('-------------------------- %s (# %d) -------------------------- '%(domain.upper(), dial_num))
    if eval_bspn:
        logging.info('[DST] joint goal:%2.1f slot acc: %2.1f slot f1: %2.1f act f1: %2.1f'%(jg, slot_acc, slot_f1, act_f1))
        logging.info('[DST] [not eval name slots] joint goal:%2.1f slot acc: %2.1f slot f1: %2.1f'%(jg_nn, sac_nn, sf1_nn))
        logging.info('[DST] [not eval book slots] joint goal:%2.1f slot acc: %2.1f slot f1: %2.1f'%(jg_nb, sac_nb, sf1_nb))
        logging.info('[DST] [not eval name & book slots] joint goal:%2.1f slot acc: %2.1f slot f1: %2.1f'%(jg_nnnb, sac_nnnb, sf1_nnnb))
    if eval_bsdx:
        logging.info('[BDX] joint goal:%2.1f slot acc: %2.1f slot f1: %2.1f act f1: %2.1f'%(jg_, slot_acc_, slot_f1_, act_f1))
        logging.info('[BDX] [not eval name slots] joint goal:%2.1f slot acc: %2.1f slot f1: %2.1f'%(jg_nn_, sac_nn_, sf1_nn_))
    logging.info('[CTR] match: %2.1f success: %2.1f bleu: %2.1f'%(match, success, bleu))
    logging.info('[CTR] ' + '; '.join(['%s: %2.1f' %(req,acc) for req, acc in req_slots_acc.items()]))
    logging.info('[DOM] accuracy: single %2.1f / multi: %2.1f (%d)'%(accu_single_dom, accu_multi_dom, multi_dom_num))
    if self.reader.multi_acts_record is not None:
        logging.info('[MA] avg acts num %2.1f avg slots num: %2.1f '%(avg_act_num, avg_diverse_score))
    return metric_result