in curiosity/stats.py
def score(self, data_path: str):
    # Assumes module-level imports defined elsewhere in curiosity/stats.py:
    # numpy as np, the module logger `log`, CuriosityDialogReader,
    # ASSISTANT_IDX, and tokens_to_str.
    dialogs = CuriosityDialogReader().read(data_path)
    n_assistant_messages = 0
    all_rr = []
    for d in dialogs:
        msg_history = []
        dialog_senders = d["senders"].array
        dialog_facts = d["facts"]
        dialog_fact_labels = d["fact_labels"]
        dialog_messages = d["messages"]
        for msg, sender, facts, fact_labels in zip(
            dialog_messages, dialog_senders, dialog_facts, dialog_fact_labels
        ):
            if sender == ASSISTANT_IDX:
                # Score every candidate fact against the dialog so far
                context = " ".join(msg_history)
                fact_texts = [tokens_to_str(tokens) for tokens in facts]
                doc_scores = self._similarity.score(context, fact_texts)
                # Fact indices sorted from highest to lowest score
                sorted_scores = np.argsort(-np.array(doc_scores))
                exists_rel_doc = False
                best_rank = None
                for rel_idx in fact_labels.array:
                    if rel_idx != -1:
                        exists_rel_doc = True
                        # 1-based rank of this relevant fact
                        rank = np.where(sorted_scores == rel_idx)[0][0] + 1
                        # Keep only the best rank when there are
                        # multiple relevant facts
                        if best_rank is None or rank < best_rank:
                            best_rank = rank
                # Skip this message if it has no relevant fact
                if exists_rel_doc:
                    all_rr.append(1 / best_rank)
                n_assistant_messages += 1
            # Append the message (user or assistant) to the history only
            # after prediction, so the context never includes the message
            # currently being scored
            msg_text = tokens_to_str(msg.tokens)
            msg_history.append(msg_text)
    mean_rr = np.mean(all_rr)
    log.info(f"Msgs with Facts: {len(all_rr)}")
    log.info(f"Total Assistant Msgs: {n_assistant_messages}")
    log.info(f"MRR: {mean_rr}")
    return mean_rr
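
For reference, the core per-message computation is the reciprocal rank of the best-ranked relevant fact. Below is a minimal, self-contained sketch of that step with toy scores; `reciprocal_rank` is an illustrative helper written for this example, not part of the module.

import numpy as np

def reciprocal_rank(doc_scores, relevant_indices):
    # Document indices sorted from highest to lowest score
    order = np.argsort(-np.array(doc_scores))
    best_rank = None
    for rel_idx in relevant_indices:
        if rel_idx == -1:
            # -1 marks "no relevant fact" in the labels
            continue
        # 1-based rank of this relevant document
        rank = np.where(order == rel_idx)[0][0] + 1
        if best_rank is None or rank < best_rank:
            best_rank = rank
    return None if best_rank is None else 1 / best_rank

# The relevant doc (index 2) has the second-highest score, so RR = 1/2
assert reciprocal_rank([0.1, 0.9, 0.4], [2]) == 0.5

Averaging these per-message reciprocal ranks over all messages that have at least one relevant fact gives the MRR that `score` logs and returns.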