in curiosity/stats.py [0:0]
def train(self, data_path: str) -> None:
    """Fit the majority dialog-act baseline from the dialogs at ``data_path``.

    Dialogs are read with ``CuriosityDialogReader``. For assistant turns
    only, every dialog act is counted both overall (``self._count``) and
    per turn index (``self._count_per_turn``). The most frequent act
    overall is stored in ``self._majority`` and the most frequent act at
    each turn index in ``self._majority_per_turn``; a turn index at which
    no assistant act was counted falls back to the overall majority.
    When counts tie, the act that compares greatest wins.

    Args:
        data_path: Path handed to ``CuriosityDialogReader().read``.

    Raises:
        IndexError: If no assistant dialog act was counted, so there is no
            majority act to select.
    """
    log.info(f"Training majority classifier with: {data_path}")
    # NOTE(review): only this counter is reset here; ``_count``,
    # ``_count_per_turn`` and ``_n_total_acts`` keep whatever they held
    # before the call -- confirm that accumulating across calls is intended.
    self._n_total_assistant_msgs = 0
    n_messages = 0
    dialogs = CuriosityDialogReader().read(data_path)
    log.info(f"N Dialogs: {len(dialogs)}")

    for dialog in dialogs:
        dialog_senders = dialog["senders"].array
        dialog_acts_list = dialog["dialog_acts"]
        for turn_idx, sender in enumerate(dialog_senders):
            acts = dialog_acts_list[turn_idx].labels
            # Every turn index gets a histogram, even if only user messages
            # occur there, so the fallback below applies to it.
            turn_counts = self._count_per_turn.setdefault(turn_idx, {})
            # Only care about assistant messages
            if sender == ASSISTANT_IDX:
                for act in acts:
                    # Histogram stat per turn
                    turn_counts[act] = turn_counts.get(act, 0) + 1
                    # Histogram stat overall
                    self._count[act] = self._count.get(act, 0) + 1
                    # Total count
                    self._n_total_acts += 1
                self._n_total_assistant_msgs += 1
            n_messages += 1

    # Clamp to at least 1 -- presumably so later ratios never divide by
    # zero; verify against the code that reads this attribute.
    self._n_total_assistant_msgs = max(1, self._n_total_assistant_msgs)
    # Each label is paired with its own counter (they used to be swapped).
    log.info(f"N Total Assistant Messages: {self._n_total_assistant_msgs}")
    log.info(f"N Total Acts: {self._n_total_acts}")
    log.info(f"N Total Messages: {n_messages}")

    if not self._count:
        # Same exception type the bare ``lst[0]`` lookup used to raise,
        # but with a message that explains what went wrong.
        raise IndexError(f"No assistant dialog acts found in: {data_path}")

    # Overall majority act; max over (count, act) picks the same element a
    # descending sort would put first, so ties resolve identically.
    self._majority = max(
        (count, act) for act, count in self._count.items()
    )[1]

    for turn_idx, act_stat in self._count_per_turn.items():
        if act_stat:
            majority_act = max(
                (count, act) for act, count in act_stat.items()
            )[1]
        else:
            # No assistant act seen at this turn index
            majority_act = self._majority
        # Majority act in this turn
        self._majority_per_turn[turn_idx] = majority_act
        print("Turn: %d, Majority Act: %s" % (turn_idx, majority_act))

    log.info(f"Majority Act: {self._majority}")
    log.info(f"Majority Map: {self._majority_per_turn}")
    log.info(f"Count Map Per Turn: {self._count_per_turn}")
    log.info(f"Count Map: {self._count}")