in visualize/generate.py [0:0]
def count_by_epoch(data_by_epoch: Dict[str, List[LogItem]]) -> Dict[str, Dict[str, int]]:
    """Tally log items per communication type for every epoch.

    The ``'init'`` entry of ``data_by_epoch`` is skipped entirely, and items
    whose ``comm_type`` equals ``CommType.epoch_end`` are not counted.

    Args:
        data_by_epoch: Mapping from epoch key to the log items recorded
            under that epoch.

    Returns:
        A mapping from each epoch key (other than ``'init'``) to a dict that
        maps the ``name`` of each counted ``comm_type`` to how many times it
        occurred in that epoch. An epoch with no countable items maps to an
        empty dict. Inner dicts keep the order in which each type was first
        seen.
    """
    counts_per_epoch = {}
    for epoch_key, items in data_by_epoch.items():
        if epoch_key != 'init':
            tally = {}
            for item in items:
                kind = item.comm_type
                # epoch_end entries are excluded from the per-type counts.
                if kind != CommType.epoch_end:
                    tally[kind.name] = tally.get(kind.name, 0) + 1
            counts_per_epoch[epoch_key] = tally
    return counts_per_epoch