in main.py [0:0]
def evaluation(true_convers, predict_convers):
    """Score predicted per-turn violations against the ground truth.

    Both arguments are mappings from a conversation key to a sequence of
    turn dicts. Each turn dict is read for three keys: ``'utteranceId'``,
    ``'violations'`` (an iterable of 2-item pairs) and ``'dialog_states'``
    (a container supporting ``in``).

    Args:
        true_convers: Ground-truth conversations.
        predict_convers: Predicted conversations; must have the same keys,
            the same number of turns per key, and matching utterance ids.

    Returns:
        The mapping returned by ``calculate_both_prfs`` with three extra
        entries added: ``'exact_match'`` (mean per-turn exact match),
        ``'iou'`` (mean per-turn intersection-over-union of violation sets)
        and ``'conversation_correct'`` (fraction of conversations in which
        every turn matched exactly).

    Raises:
        AssertionError: If the two inputs are not aligned (different sizes,
            missing conversation key, different turn counts, or mismatched
            utterance ids).
    """
    assert len(true_convers) == len(predict_convers), \
        "ground truth and prediction have a different number of conversations"

    all_turn_pds = set()
    all_turn_gts = set()
    all_turn_dependent_pds = set()
    all_turn_dependent_gts = set()
    exact_match_turn = []
    iou_turn = []
    conversation_all_correct = []

    for key in true_convers:
        assert key in predict_convers, \
            f"conversation {key!r} is missing from the predictions"
        gt_turns = true_convers[key]
        pred_turns = predict_convers[key]
        assert len(gt_turns) == len(pred_turns), \
            f"conversation {key!r} has a different number of turns"

        conversation_correct = 1
        # Lengths were checked above, so zip pairs every turn exactly once.
        for gt, pred in zip(gt_turns, pred_turns):
            assert gt['utteranceId'] == pred['utteranceId']
            utterance_id = gt['utteranceId']

            # Each violation is unpacked as a pair; its first element is
            # looked up in the *other* side's dialog_states to build the
            # "dependent" subsets.
            # NOTE(review): the meaning of the pair elements is not visible
            # here -- confirm against the data producer.
            for vio_key, vio_val in pred['violations']:
                all_turn_pds.add((utterance_id, vio_key, vio_val))
                if vio_key in gt['dialog_states']:
                    all_turn_dependent_pds.add((utterance_id, vio_key, vio_val))
            for vio_key, vio_val in gt['violations']:
                all_turn_gts.add((utterance_id, vio_key, vio_val))
                if vio_key in pred['dialog_states']:
                    all_turn_dependent_gts.add((utterance_id, vio_key, vio_val))

            # Normalize BOTH sides to tuples: pairs that arrive as lists
            # (e.g. loaded from JSON) are unhashable and would otherwise
            # make the set construction fail for the predictions.
            gt_vios = {tuple(x) for x in gt['violations']}
            pd_vios = {tuple(x) for x in pred['violations']}

            if gt_vios == pd_vios:
                # Also covers the case where both sets are empty, which
                # would otherwise divide by zero below.
                this_iou = 1
            else:
                this_iou = len(gt_vios & pd_vios) / len(gt_vios | pd_vios)
            iou_turn.append(this_iou)

            this_exact_match = 1 if this_iou == 1 else 0
            exact_match_turn.append(this_exact_match)
            if not this_exact_match:
                # A single wrong turn marks the whole conversation wrong.
                conversation_correct = 0
        conversation_all_correct.append(conversation_correct)

    result_this_iteration = calculate_both_prfs(
        all_turn_gts, all_turn_pds,
        all_turn_dependent_gts, all_turn_dependent_pds,
    )
    # Printed before the aggregate entries below are added.
    pprint(result_this_iteration)
    result_this_iteration['exact_match'] = np.mean(exact_match_turn)
    result_this_iteration['iou'] = np.mean(iou_turn)
    result_this_iteration['conversation_correct'] = np.mean(conversation_all_correct)
    return result_this_iteration