in curiosity/stats.py [0:0]
def score(self, data_path: str) -> float:
    """Score the majority dialog-act baseline on ``data_path`` and return its F1.

    Only messages whose sender is not ``ASSISTANT_IDX`` are scored. For each
    scored message at turn index ``i``, the baseline makes a single prediction:
    ``self._majority_per_turn[i]`` when ``i`` is present in that mapping,
    otherwise ``self._majority``. Every gold act of the message that equals the
    prediction counts as one hit.

    Precision is hits per scored message (one prediction per message); recall
    is hits per gold act of the scored messages; the result is their harmonic
    mean.

    Args:
        data_path: Path handed to ``CuriosityDialogReader().read``.

    Returns:
        The F1 score, or ``0.0`` when there are no hits (this includes an empty
        dataset), where the harmonic-mean formula would divide by zero.
    """
    log.info(f"Scoring majority classifier with: {data_path}")
    dialogs = CuriosityDialogReader().read(data_path)
    log.info(f"N Dialogs: {len(dialogs)}")
    correct = 0
    total = 0
    n_messages = 0
    for d in dialogs:
        dialog_senders = d["senders"].array
        dialog_acts_list = d["dialog_acts"]
        for i, sender in enumerate(dialog_senders):
            if sender == ASSISTANT_IDX:
                # Assistant messages are not scored.
                continue
            acts = dialog_acts_list[i].labels
            # One prediction per message, so resolve it once rather than per act.
            predicted = (
                self._majority_per_turn[i]
                if i in self._majority_per_turn
                else self._majority
            )
            correct += sum(act == predicted for act in acts)
            total += len(acts)
            n_messages += 1
    log.info(f"N Correct Acts: {correct}")
    log.info(f"N Total Acts: {total}")
    log.info(f"N Total Messages: {n_messages}")
    if correct == 0:
        # Precision and recall are both 0 here; F1 is conventionally 0 rather
        # than the ZeroDivisionError that 2*p*r/(p+r) would raise.
        return 0.0
    # correct > 0 implies n_messages >= 1 and total >= 1, so no zero divisors.
    precision = correct / n_messages  # assumes 1 prediction per message
    recall = correct / total
    return 2 * (precision * recall) / (precision + recall)