def domain_eval()

in E2E_TOD/eval.py [0:0]


    def domain_eval(self, data, eval_dial_list = None):
        """Measure turn-level domain prediction accuracy.

        Dialogues come from ``self.pack_dial(data)``, indexed by dialogue id,
        each a list of turn dicts. When ``eval_dial_list`` is non-empty, only
        dialogues whose ``'<dial_id>.json'`` appears in it are scored. The
        first turn of every dialogue is not scored.

        For each scored turn the gold domains are
        ``self.reader.dspan_to_domain(turn['dspn'])``. The prediction is
        ``self.reader.dspan_to_domain(turn['dspn_gen'])`` when
        ``self.cfg.enable_dspn`` is set; otherwise it is inferred:

        * belief-span domains whose constraints are new or changed compared
          with the previous scored turn,
        * then ``[domain]`` tokens from the act span (excluding ``booking``),
        * falling back to the previous turn's domains when nothing was found;
          ``general`` is dropped from a two-domain result, and a two-domain
          result is reordered so a single previous-turn domain comes first.

        In the inferred case the domain span is written to ``turn['dspn_gen']``.

        Side effects: each mispredicted turn gets ``turn['wrong_domain'] = 'x'``
        and the first turn of each scored dialogue gets ``'wrong_domain'`` set
        to the space-joined ``turn_num`` values of its mispredicted turns.

        Args:
            data: evaluation records, grouped into dialogues by ``self.pack_dial``.
            eval_dial_list: optional collection of ``'<dial_id>.json'`` names
                restricting which dialogues are scored; falsy means all.

        Returns:
            tuple: ``(single_acc, multi_acc, total_multi)``; accuracies are
            percentages for turns with exactly one gold domain and for all
            other turns respectively, ``total_multi`` is the count of the latter.
        """
        dials = self.pack_dial(data)
        corr_single, total_single, corr_multi, total_multi = 0, 0, 0, 0
        # Build the filter once so each dialogue lookup is O(1).
        eval_dial_set = set(eval_dial_list) if eval_dial_list else None

        infer_domains = not self.cfg.enable_dspn
        if infer_domains:
            # Config is fixed for the whole run, so resolve the span keys once.
            # All settings are read from self.cfg (the evaluator's own config).
            use_gen_bspn = (self.cfg.enable_bspn
                            and not self.cfg.use_true_bspn_for_ctr_eval
                            and (self.cfg.bspn_mode == 'bspn' or self.cfg.enable_dst))
            bspn_key = 'bspn_gen' if use_gen_bspn else 'bspn'
            aspn_key = 'aspn_gen' if self.cfg.enable_aspn else 'aspn'
            # Domain names recognised inside act-span tokens such as '[hotel]'.
            act_domains = set(ontology.all_domains) | {'general'}

        for dial_id, dial in dials.items():
            if eval_dial_set is not None and dial_id + '.json' not in eval_dial_set:
                continue
            wrong_pred = []
            prev_constraint_dict = {}
            prev_turn_domain = ['general']

            # The first turn of each dialogue is not scored.
            for turn in dial[1:]:
                true_domains = self.reader.dspan_to_domain(turn['dspn'])
                if not infer_domains:
                    pred_domains = self.reader.dspan_to_domain(turn['dspn_gen'])
                else:
                    constraint_dict = self.reader.bspan_to_constraint_dict(turn[bspn_key])
                    # Belief-span domains whose constraints are new or changed.
                    turn_domain = [
                        dom for dom in constraint_dict
                        if dom not in prev_constraint_dict
                        or prev_constraint_dict[dom] != constraint_dict[dom]
                    ]
                    # Then domains named in the act span, in order, excluding 'booking'.
                    for token in turn[aspn_key].split():
                        dom = token[1:-1]
                        if dom in act_domains and dom != 'booking' and dom not in turn_domain:
                            turn_domain.append(dom)
                    if not turn_domain:
                        # Nothing new this turn: carry the previous domain(s) forward.
                        turn_domain = list(prev_turn_domain)
                    if len(turn_domain) == 2 and 'general' in turn_domain:
                        turn_domain.remove('general')
                    # With two domains, put the one continued from a single-domain
                    # previous turn first.
                    if (len(turn_domain) == 2 and len(prev_turn_domain) == 1
                            and prev_turn_domain[0] == turn_domain[1]):
                        turn_domain.reverse()
                    prev_turn_domain = list(turn_domain)
                    prev_constraint_dict = copy.deepcopy(constraint_dict)

                    turn['dspn_gen'] = ' '.join('[' + dom + ']' for dom in turn_domain)
                    pred_domains = {'[' + dom + ']': 1 for dom in turn_domain}

                correct = pred_domains == true_domains
                if len(true_domains) == 1:
                    total_single += 1
                    if correct:
                        corr_single += 1
                else:
                    total_multi += 1
                    if correct:
                        corr_multi += 1
                if not correct:
                    wrong_pred.append(str(turn['turn_num']))
                    turn['wrong_domain'] = 'x'

            # Dialogue-level record of mispredicted turns (skip empty dialogues).
            if dial:
                dial[0]['wrong_domain'] = ' '.join(wrong_pred)

        # The epsilon avoids division by zero when a category has no turns.
        accu_single = corr_single / (total_single + 1e-10)
        accu_multi = corr_multi / (total_multi + 1e-10)
        return accu_single * 100, accu_multi * 100, total_multi