def count_by_epoch()

in visualize/generate.py [0:0]


def count_by_epoch(data_by_epoch: Dict[str, List[LogItem]]) -> Dict[str, Dict[str, int]]:
    """Tally log items per communication type for every epoch.

    The ``'init'`` epoch is left out of the result entirely, and items whose
    ``comm_type`` is ``CommType.epoch_end`` are not counted.

    Args:
        data_by_epoch: Mapping from epoch key to the log items of that epoch.

    Returns:
        A mapping ``epoch key -> {comm_type name -> number of items}``. An
        epoch with no countable items maps to an empty dict.
    """
    tallies = {}

    for epoch_key, entries in data_by_epoch.items():
        if epoch_key == 'init':
            continue

        # Lazily yield the type name of every item that should be counted.
        counted_names = (
            entry.comm_type.name
            for entry in entries
            if entry.comm_type != CommType.epoch_end
        )

        per_type = {}
        for type_name in counted_names:
            per_type[type_name] = per_type.get(type_name, 0) + 1

        tallies[epoch_key] = per_type

    return tallies