in curiosity/stats.py [0:0]
def score(self, data_path: str) -> float:
    """Return the accuracy of the majority-class baseline on ``data_path``.

    Dialogs are read with ``CuriosityDialogReader``. For every message sent by
    the assistant (``sender == ASSISTANT_IDX``), the baseline predicts "liked"
    when ``self._like_all`` is truthy and "not liked" otherwise. The prediction
    counts as correct when it agrees with the message's gold like label.

    Args:
        data_path: Path passed to ``CuriosityDialogReader().read``.

    Returns:
        The fraction of assistant messages predicted correctly. Returns 0.0
        when there are no assistant messages; the denominator is clamped to 1
        to avoid division by zero.
    """
    log.info(f"Scoring majority classifier with: {data_path}")
    dialogs = CuriosityDialogReader().read(data_path)
    log.info(f"N Dialogs: {len(dialogs)}")
    # The baseline makes the same prediction for every assistant message.
    predicts_liked = bool(self._like_all)
    correct = 0
    total = 0
    n_messages = 0
    for d in dialogs:
        dialog_senders = d["senders"].array
        dialog_likes = d["likes"]
        for sender, liked in zip(dialog_senders, dialog_likes):
            if sender == ASSISTANT_IDX:
                is_liked = liked.label == "liked"
                # Correct when the gold label matches the majority class:
                # liked/liked or not-liked/not-liked. (Previously the second
                # case wrongly tested for "liked", inverting the score when
                # the majority class was "not liked".)
                if is_liked == predicts_liked:
                    correct += 1
                total += 1
            # Counts every (sender, like) pair, assistant or not.
            n_messages += 1
    log.info(f"N Correct Assistant Messages: {correct}")
    log.info(f"N Total Assistant Messages: {total}")
    log.info(f"N Total Messages: {n_messages}")
    total = max(1, total)
    return correct / total