def train()

in curiosity/stats.py [0:0]


    def train(self, data_path: str) -> None:
        """Count assistant dialog acts and record the majority acts.

        Reads the dialogs at ``data_path`` with ``CuriosityDialogReader`` and,
        for every message whose sender equals ``ASSISTANT_IDX``, increments:

        * ``self._count`` - act histogram over all turns,
        * ``self._count_per_turn`` - act histogram keyed by turn index,
        * ``self._n_total_acts`` and ``self._n_total_assistant_msgs``.

        The most frequent act overall is stored in ``self._majority`` and the
        most frequent act at each turn index in ``self._majority_per_turn``.
        Turn indices without any assistant act fall back to the overall
        majority. Ties are resolved by comparing ``(count, act)`` tuples, so
        among equally frequent acts the one that compares greatest wins.

        Only ``self._n_total_assistant_msgs`` is reset here; the histograms and
        ``self._n_total_acts`` are updated in place, so they are presumably
        initialised elsewhere (e.g. the constructor) - verify before calling
        ``train`` more than once on the same instance.

        Args:
            data_path: Path passed to ``CuriosityDialogReader().read``.

        Raises:
            IndexError: If no assistant act has been counted, because there is
                then no overall majority act to select.
        """
        log.info(f"Training majority classifier with: {data_path}")
        self._n_total_assistant_msgs = 0
        n_messages = 0
        dialogs = CuriosityDialogReader().read(data_path)
        log.info(f"N Dialogs: {len(dialogs)}")
        for d in dialogs:
            dialog_acts_list = d["dialog_acts"]

            for i, sender in enumerate(d["senders"].array):
                acts = dialog_acts_list[i].labels

                # Register the turn index even when the message is not from the
                # assistant, so that every seen turn gets a majority entry.
                turn_counts = self._count_per_turn.setdefault(i, {})

                # Only care about assistant messages
                if sender == ASSISTANT_IDX:
                    for act in acts:
                        # Histogram stat per turn
                        turn_counts[act] = turn_counts.get(act, 0) + 1

                        # Histogram stat overall
                        self._count[act] = self._count.get(act, 0) + 1

                        # Total count
                        self._n_total_acts += 1

                    self._n_total_assistant_msgs += 1
                n_messages += 1

        # Clamp to at least 1 (presumably so the count is safe to divide by
        # later - confirm against the callers).
        self._n_total_assistant_msgs = max(1, self._n_total_assistant_msgs)
        # Each label is paired with its own counter (they used to be swapped).
        log.info(f"N Total Assistant Messages: {self._n_total_assistant_msgs}")
        log.info(f"N Total Acts: {self._n_total_acts}")
        log.info(f"N Total Messages: {n_messages}")

        # Sort count overall; (count, act) tuples order by count, then by act.
        lst = [(count, act) for act, count in self._count.items()]
        lst.sort(reverse=True)

        # Majority act over all turns
        self._majority = lst[0][1]

        for turn_idx, act_stat in self._count_per_turn.items():
            # Sort count_per_turn for each turn_idx
            lst = [(count, act) for act, count in act_stat.items()]
            lst.sort(reverse=True)

            # Turns with no assistant acts inherit the overall majority.
            majority_act = lst[0][1] if lst else self._majority

            # Majority act in this turn
            self._majority_per_turn[turn_idx] = majority_act
            print("Turn: %d, Majority Act: %s" % (turn_idx, majority_act))

        log.info(f"Majority Act: {self._majority}")
        log.info(f"Majority Map: {self._majority_per_turn}")
        log.info(f"Count Map Per Turn: {self._count_per_turn}")
        log.info(f"Count Map: {self._count}")