def train()

in curiosity/stats.py [0:0]


    def train(self, data_path: str) -> None:
        """Fit the majority-class baseline from the dialogs at ``data_path``.

        Reads the dialogs with ``CuriosityDialogReader``, counts how many
        assistant messages there are and how many of them carry the label
        ``"liked"``, and then decides whether "liked" is the majority class.

        Side effects (all on ``self``):
            * ``_n_liked_assistant_msgs`` -- number of liked assistant messages.
            * ``_n_total_assistant_msgs`` -- number of assistant messages,
              clamped to a minimum of 1.
            * ``_like_all`` -- ``True`` when strictly more than half of the
              assistant messages are liked, ``False`` otherwise.

        Args:
            data_path: Location of the dialog data handed to the reader.
        """
        log.info(f"Training majority classifier with: {data_path}")

        # Reset the tallies so that every call starts counting from zero.
        self._n_total_assistant_msgs = 0
        self._n_liked_assistant_msgs = 0
        seen_messages = 0

        corpus = CuriosityDialogReader().read(data_path)
        log.info(f"N Dialogs: {len(corpus)}")

        for dialog in corpus:
            # Senders and like labels are walked in lockstep, one pair per message.
            for speaker, like_field in zip(dialog["senders"].array, dialog["likes"]):
                seen_messages += 1
                from_assistant = speaker == ASSISTANT_IDX
                if not from_assistant:
                    # Non-assistant messages only contribute to the overall tally.
                    continue
                if like_field.label == "liked":
                    self._n_liked_assistant_msgs += 1
                self._n_total_assistant_msgs += 1

        # Clamp to at least 1 so the ratio below can never divide by zero.
        self._n_total_assistant_msgs = max(1, self._n_total_assistant_msgs)

        log.info(f"N Liked Assistant Messages: {self._n_liked_assistant_msgs}")
        log.info(f"N Total Assistant Messages: {self._n_total_assistant_msgs}")
        log.info(f"N Total Messages: {seen_messages}")

        # "Liked" wins only with a strict majority; an exact 50/50 split does not.
        liked_fraction = self._n_liked_assistant_msgs / self._n_total_assistant_msgs
        self._like_all = liked_fraction > 0.5
        log.info(f"Majority Class Liked: {self._like_all}")