# Excerpt: split_data_by_epoch()
# from visualize/generate.py [0:0]


def split_data_by_epoch(is_comm: int, log_items: List[LogItem]) -> Dict[str, List[LogItem]]:
    """Partition log items into consecutive per-epoch buckets.

    Items are appended, in input order, to the currently open bucket. The
    first bucket is labelled ``'init'``. When an item's ``is_epoch_end()``
    is truthy, that item is still placed in the open bucket. A fresh bucket
    is then opened under ``'epoch_<k>'``, where ``k`` is 0 for the bucket
    opened by the first marker, 1 for the next, and so on.

    Args:
        is_comm: When truthy, items whose ``comm_type`` equals
            ``CommType.computation`` are skipped and land in no bucket.
        log_items: The log items to partition.

    Returns:
        Bucket label -> items, in the order the buckets were opened. The
        last opened bucket is omitted if it is empty. That happens when the
        input ends on an epoch-end marker or when no item was kept.
    """
    open_label = 'init'
    open_bucket = []
    buckets = {open_label: open_bucket}
    markers_seen = 0

    for item in log_items:
        # Optionally keep only non-computation entries.
        if is_comm and item.comm_type == CommType.computation:
            continue

        closes_epoch = item.is_epoch_end()
        # The epoch-end marker belongs to the bucket it closes.
        open_bucket.append(item)
        if closes_epoch:
            open_label = f'epoch_{markers_seen}'
            open_bucket = []
            buckets[open_label] = open_bucket
            markers_seen += 1

    # Drop the bucket that is still open if nothing landed in it.
    if not open_bucket:
        buckets.pop(open_label)
    return buckets