in curiosity/stats.py [0:0]
def train(self, data_path: str) -> None:
    """Fit the majority-class baseline from the dialogs at ``data_path``.

    Reads the dialogs with ``CuriosityDialogReader``, walks every
    (sender, like) pair, and counts the assistant messages and how many
    of them carry the label ``"liked"``. ``self._like_all`` is then set
    to whether the liked fraction is strictly greater than 0.5.

    Args:
        data_path: Passed unchanged to ``CuriosityDialogReader().read``.

    Attributes written:
        _n_liked_assistant_msgs: Assistant messages labeled ``"liked"``.
        _n_total_assistant_msgs: Assistant messages seen, clamped to a
            minimum of 1 after counting.
        _like_all: ``True`` when more than half of the assistant
            messages were liked, otherwise ``False``.
    """
    log.info(f"Training majority classifier with: {data_path}")
    self._n_total_assistant_msgs = 0
    self._n_liked_assistant_msgs = 0
    n_messages = 0

    dialogs = CuriosityDialogReader().read(data_path)
    log.info(f"N Dialogs: {len(dialogs)}")

    for dialog in dialogs:
        senders = dialog["senders"].array
        likes = dialog["likes"]
        for sender, like in zip(senders, likes):
            # Every message counts toward the overall total.
            n_messages += 1
            # Only assistant messages feed the like statistics.
            if not (sender == ASSISTANT_IDX):
                continue
            if like.label == "liked":
                self._n_liked_assistant_msgs += 1
            self._n_total_assistant_msgs += 1

    # Clamp to at least 1 so the ratio below can never divide by zero.
    # The clamped value is what gets stored and logged, so a dataset with
    # no assistant messages reports a total of 1.
    self._n_total_assistant_msgs = max(1, self._n_total_assistant_msgs)

    log.info(f"N Liked Assistant Messages: {self._n_liked_assistant_msgs}")
    log.info(f"N Total Assistant Messages: {self._n_total_assistant_msgs}")
    log.info(f"N Total Messages: {n_messages}")

    # Strictly greater than one half: an exact tie resolves to "not liked".
    liked_fraction = self._n_liked_assistant_msgs / self._n_total_assistant_msgs
    self._like_all = liked_fraction > 0.5
    log.info(f"Majority Class Liked: {self._like_all}")