in model/mm_dst/utils/evaluate_dst.py [0:0]
# Slot keys whose values are compared as order-insensitive lists.
_LIST_VALUED_SLOT_KEYS = frozenset({"availableSizes"})


def _coerce_to_list(value):
    """Coerce a list-valued slot's value into a list where possible.

    Strings are parsed as Python literals (e.g. ``"['S', 'M']"`` or
    ``"'S', 'M'"``). A string that does not parse into a collection is kept
    as a single-element list. Tuples and (frozen)sets become lists; any other
    value is returned unchanged.
    """
    # Local import so this block does not rely on the file's import header.
    import ast

    if isinstance(value, str):
        # literal_eval instead of eval: predicted frames come from model
        # output, and eval would execute arbitrary expressions.
        try:
            parsed = ast.literal_eval(value)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return [value]
        if isinstance(parsed, str):
            # A single quoted size like "'XL'" must not be split into chars.
            return [parsed]
        try:
            return list(parsed)
        except TypeError:
            # Non-iterable literal (e.g. a number): keep the raw string whole.
            return [value]
    if isinstance(value, (tuple, set, frozenset)):
        return list(value)
    return value


def _normalize_slot_value(key, value):
    """Return a canonical form of one slot value for set-based comparison."""
    if key not in _LIST_VALUED_SLOT_KEYS:
        return value
    value = _coerce_to_list(value)
    if isinstance(value, list):
        # sorted() rather than list.sort() so the caller's frame is never
        # mutated; sorting makes the comparison order-insensitive, so all
        # elements must be captured but their order does not matter.
        value = sorted(value)
    return value


def _slot_value_strings(frame):
    """Build the set of ``"key=value"`` strings for a frame's ``slots``."""
    return {
        f"{key}={_normalize_slot_value(key, value)}"
        for key, value in frame.get("slots", [])
    }


def evaluate_frame(true_frame, pred_frame, strict=True):
    """Compare one predicted dialog frame against the ground-truth frame.

    Counts dialog acts, slots (as ``"key=value"`` strings), request slots
    and objects present in each frame, plus how many of each the prediction
    got right.

    If ``strict`` is True, slot values, request slots and objects only earn
    credit when the dialog act also matches; with a wrong act they are all
    counted as wrong. The joint ``n_correct_beliefs`` count always requires
    the act and all three sets to match exactly, regardless of ``strict``.

    Args:
        true_frame: Ground-truth frame dict; may contain "act", "slots"
            (iterable of (key, value) pairs), "request_slots" and "objects".
        pred_frame: Predicted frame dict with the same optional keys.
        strict: Whether per-item credit requires a correct dialog act.

    Returns:
        The counters returned by ``initialize_count_dict()`` with this
        frame's counts added.
    """
    count_dict = initialize_count_dict()
    count_dict["n_frames"] += 1

    # Compare dialog acts; a missing act is treated as None on either side.
    true_act = true_frame.get("act")
    pred_act = pred_frame.get("act")
    b_correct_act = true_act == pred_act
    count_dict["n_correct_acts"] += b_correct_act
    count_dict["n_true_acts"] += "act" in true_frame
    count_dict["n_pred_acts"] += "act" in pred_frame

    # In strict mode a wrong act forfeits all per-item credit below.
    give_item_credit = b_correct_act or not strict

    # (1) Compare slots.
    true_slots = _slot_value_strings(true_frame)
    pred_slots = _slot_value_strings(pred_frame)
    count_dict["n_true_slots"] += len(true_slots)
    count_dict["n_pred_slots"] += len(pred_slots)
    if give_item_credit:
        count_dict["n_correct_slots"] += len(true_slots & pred_slots)

    # (2) Compare request slots.
    true_request_slots = set(true_frame.get("request_slots", []))
    pred_request_slots = set(pred_frame.get("request_slots", []))
    count_dict["n_true_request_slots"] += len(true_request_slots)
    count_dict["n_pred_request_slots"] += len(pred_request_slots)
    if give_item_credit:
        count_dict["n_correct_request_slots"] += len(
            true_request_slots & pred_request_slots
        )

    # (3) Compare objects.
    true_objects = set(true_frame.get("objects", []))
    pred_objects = set(pred_frame.get("objects", []))
    count_dict["n_true_objects"] += len(true_objects)
    count_dict["n_pred_objects"] += len(pred_objects)
    if give_item_credit:
        count_dict["n_correct_objects"] += len(true_objects & pred_objects)

    # (4) Joint belief: act and every set must match exactly.
    count_dict["n_correct_beliefs"] += (
        b_correct_act
        and true_slots == pred_slots
        and true_request_slots == pred_request_slots
        and true_objects == pred_objects
    )
    return count_dict