def evaluate_frame()

in model/mm_dst/utils/evaluate_dst.py [0:0]


# Slot names whose values are collections of elements; these are canonicalized
# (converted to a sorted list) before comparison so element order is ignored.
_LIST_VALUED_SLOTS = frozenset({"availableSizes"})


def _normalize_slot_value(slot_name, slot_value):
    """Return a canonical form of `slot_value` suitable for string comparison.

    Values of slots not listed in `_LIST_VALUED_SLOTS` are returned unchanged.
    For list-valued slots:
        - a string is parsed as a Python literal and converted to a list;
          if that fails, the string is wrapped in a single-element list,
        - a tuple / set / frozenset is converted to a list,
        - the resulting list is returned as a new, sorted list.
    The input object is never modified.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import section.
    import ast

    if slot_name not in _LIST_VALUED_SLOTS:
        return slot_value

    if isinstance(slot_value, str):
        try:
            # SECURITY: literal_eval only accepts Python literals. Plain
            # eval() would execute arbitrary expressions (slot strings may
            # originate from model output -- TODO confirm with callers) and
            # would also resolve bare names against local variables.
            slot_value = list(ast.literal_eval(slot_value))
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # Not a parsable / iterable literal: treat it as one element.
            slot_value = [slot_value]
    elif isinstance(slot_value, (tuple, set, frozenset)):
        slot_value = list(slot_value)

    if isinstance(slot_value, list):
        # For slots with a list of elements, all elements must be captured,
        # regardless of order. sorted() builds a new list, so a list passed
        # in by the caller is left untouched.
        try:
            slot_value = sorted(slot_value)
        except TypeError:
            # Elements are not mutually comparable (e.g. str mixed with int);
            # fall back to a deterministic ordering instead of crashing.
            slot_value = sorted(
                slot_value, key=lambda item: (type(item).__name__, repr(item))
            )

    return slot_value


def _slot_value_strings(frame):
    """Return the set of "name=value" strings for every slot in `frame`."""
    return {
        f"{name}={_normalize_slot_value(name, value)}"
        for name, value in frame.get("slots", [])
    }


def evaluate_frame(true_frame, pred_frame, strict=True):
    """
    Compare one predicted frame against its ground-truth frame.

    Both frames are dicts that may contain the keys "act", "slots"
    (an iterable of (name, value) pairs), "request_slots" and "objects".
    Missing keys are treated as an absent act / empty collections.

    If strict=True,
        For each dialog_act (frame), set(slot values) must match.
        If dialog_act is incorrect, its set(slot values) is considered wrong.
        The same applies to request slots and objects.
    If strict=False,
        Slots, request slots and objects are credited even if the
        dialog_act is incorrect.

    Returns the dict produced by initialize_count_dict() with the counters
    for this single frame added to it.
    """
    count_dict = initialize_count_dict()
    count_dict["n_frames"] += 1

    # Compare Dialog Acts
    true_act = true_frame.get("act")
    pred_act = pred_frame.get("act")
    b_correct_act = true_act == pred_act
    count_dict["n_correct_acts"] += b_correct_act
    count_dict["n_true_acts"] += "act" in true_frame
    count_dict["n_pred_acts"] += "act" in pred_frame

    # In strict mode nothing else is credited unless the act is correct.
    b_credit_matches = b_correct_act or not strict

    # (1) Compare Slots
    true_frame_slot_values = _slot_value_strings(true_frame)
    pred_frame_slot_values = _slot_value_strings(pred_frame)

    count_dict["n_true_slots"] += len(true_frame_slot_values)
    count_dict["n_pred_slots"] += len(pred_frame_slot_values)

    if b_credit_matches:
        count_dict["n_correct_slots"] += len(
            true_frame_slot_values & pred_frame_slot_values
        )

    # (2) Compare Request slots
    true_frame_request_slot_values = set(true_frame.get("request_slots", []))
    pred_frame_request_slot_values = set(pred_frame.get("request_slots", []))

    count_dict["n_true_request_slots"] += len(true_frame_request_slot_values)
    count_dict["n_pred_request_slots"] += len(pred_frame_request_slot_values)

    if b_credit_matches:
        count_dict["n_correct_request_slots"] += len(
            true_frame_request_slot_values & pred_frame_request_slot_values
        )

    # (3) Compare Objects
    true_frame_object_values = set(true_frame.get("objects", []))
    pred_frame_object_values = set(pred_frame.get("objects", []))

    count_dict["n_true_objects"] += len(true_frame_object_values)
    count_dict["n_pred_objects"] += len(pred_frame_object_values)

    if b_credit_matches:
        count_dict["n_correct_objects"] += len(
            true_frame_object_values & pred_frame_object_values
        )

    # (4) Joint: the act and every collection must match exactly.
    count_dict["n_correct_beliefs"] += (
        b_correct_act
        and true_frame_slot_values == pred_frame_slot_values
        and true_frame_request_slot_values == pred_frame_request_slot_values
        and true_frame_object_values == pred_frame_object_values
    )

    return count_dict