def score()

in curiosity/stats.py [0:0]


    def score(self, data_path: str) -> float:
        """Compute the mean reciprocal rank (MRR) of relevant facts.

        For every assistant message in the dialogs read from ``data_path``,
        rank the message's candidate facts by similarity to the conversation
        history so far, then take the reciprocal of the best rank among the
        facts labeled relevant. Messages with no relevant fact contribute
        nothing to the mean (but are still counted in the assistant-message
        total logged at the end).

        Args:
            data_path: Path passed to ``CuriosityDialogReader().read``.

        Returns:
            The mean reciprocal rank over assistant messages that have at
            least one relevant fact, or ``nan`` if there are none.
        """
        dialogs = CuriosityDialogReader().read(data_path)
        n_assistant_messages = 0
        all_rr = []
        for d in dialogs:
            msg_history = []
            dialog_senders = d["senders"].array
            dialog_facts = d["facts"]
            dialog_fact_labels = d["fact_labels"]
            dialog_messages = d["messages"]
            for msg, sender, facts, fact_labels in zip(
                dialog_messages, dialog_senders, dialog_facts, dialog_fact_labels
            ):
                if sender == ASSISTANT_IDX:
                    # Score every candidate fact against the dialog so far.
                    context = " ".join(msg_history)
                    fact_texts = [tokens_to_str(tokens) for tokens in facts]
                    doc_scores = self._similarity.score(context, fact_texts)
                    # argsort of the negated scores puts the highest-scoring
                    # fact's index at position 0.
                    sorted_scores = np.argsort(-np.array(doc_scores))
                    # 1-based ranks of all relevant facts; -1 marks
                    # "not relevant" in the label array.
                    rel_ranks = [
                        np.where(sorted_scores == rel_idx)[0][0] + 1
                        for rel_idx in fact_labels.array
                        if rel_idx != -1
                    ]
                    # Ignore this example if there is no relevant doc; with
                    # multiple relevant docs, only the best rank counts.
                    if rel_ranks:
                        all_rr.append(1 / min(rel_ranks))
                    n_assistant_messages += 1

                # Only add the actually used message after prediction
                # Add user and assistant messages
                msg_text = tokens_to_str(msg.tokens)
                msg_history.append(msg_text)
        # np.mean([]) warns and returns nan; make the empty case explicit.
        mean_rr = float(np.mean(all_rr)) if all_rr else float("nan")
        log.info(f"Msgs with Facts: {len(all_rr)}")
        log.info(f"Total Assistant Msgs: {n_assistant_messages}")
        log.info(f"MRR: {mean_rr}")
        return mean_rr