in model/mm_dst/utils/evaluate_dst.py [0:0]
def evaluate_from_flat_list(d_true, d_pred):
    """Accumulate per-turn evaluation counts and derive summary metrics.

    <list>d_true and <list>d_pred are in the following format:
    (Each element represents a single turn, with (multiple) frames)
    [
        [
            {
                'act': <str>,
                'slots': [
                    [
                        SLOT_NAME, SLOT_VALUE
                    ],
                    ...
                ],
                'request_slots': [ SLOT_NAME, ... ],
                'objects': [ <int> ]
            },
            [End of a frame]
            ...
        ],
        [End of a turn]
        ...
    ]

    Args:
        d_true: Ground-truth turns.
        d_pred: Predicted turns, paired with ``d_true`` by index. It must
            hold at least ``len(d_true)`` turns; trailing extra turns are
            never read.

    Returns:
        A dict with ``joint_accuracy`` followed by ``<name>_rec``,
        ``<name>_prec``, ``<name>_f1`` and ``<name>_f1_stderr`` for each
        ``<name>`` in ``act``, ``slot``, ``request_slot`` and ``object``.
        ``joint_accuracy`` is ``n_correct_beliefs / n_frames``, reported as
        ``0.0`` when no frames were counted.

    Raises:
        IndexError: If ``d_pred`` has fewer turns than ``d_true``.
    """
    counts = initialize_count_dict()

    # Count # corrects & # wrongs, one turn at a time.  d_pred is indexed
    # (rather than zipped) so that a too-short prediction list fails loudly
    # instead of being silently truncated.
    for turn_idx, true_turn in enumerate(d_true):
        turn_evaluation = evaluate_turn(true_turn, d_pred[turn_idx])
        counts = add_dicts(counts, turn_evaluation)

    # Calculate metrics.  Guard the denominator: with no counted frames the
    # plain division would raise ZeroDivisionError.
    n_frames = counts["n_frames"]
    metrics = {
        "joint_accuracy": (
            counts["n_correct_beliefs"] / n_frames if n_frames else 0.0
        ),
    }

    # (key prefix in the returned dict, key suffix in the count dict)
    categories = (
        ("act", "acts"),
        ("slot", "slots"),
        ("request_slot", "request_slots"),
        ("object", "objects"),
    )
    for prefix, suffix in categories:
        n_correct = counts["n_correct_" + suffix]
        n_true = counts["n_true_" + suffix]
        n_pred = counts["n_pred_" + suffix]

        rec, prec, f1 = rec_prec_f1(
            n_correct=n_correct, n_true=n_true, n_pred=n_pred
        )
        metrics[prefix + "_rec"] = rec
        metrics[prefix + "_prec"] = prec
        metrics[prefix + "_f1"] = f1
        # Std err of the F1 score; d_f1 takes (n_true, n_pred, n_correct)
        # positionally, in that order.
        metrics[prefix + "_f1_stderr"] = d_f1(n_true, n_pred, n_correct)

    return metrics