in visualize/generate.py [0:0]
def split_data_by_epoch(is_comm: int, log_items: List[LogItem]) -> Dict[str, List[LogItem]]:
    """Group an ordered sequence of log items into consecutive epoch buckets.

    Items are collected into a bucket named ``'init'`` until an item whose
    ``is_epoch_end()`` returns true is seen. That item is placed in the bucket it
    closes, and a new bucket is opened: ``'epoch_0'``, then ``'epoch_1'``, and so on.

    Args:
        is_comm: When truthy, items whose ``comm_type`` equals
            ``CommType.computation`` are skipped entirely (they neither land in a
            bucket nor can close one).
        log_items: Log items in the order they should be bucketed.

    Returns:
        Dict mapping bucket name to its items, in the order the buckets were
        opened. The bucket still open when the input is exhausted is dropped if it
        received no items; in particular, an input that yields no kept items
        produces an empty dict.
    """
    # Lazy filter: the computation check for an item still runs immediately
    # before that item's is_epoch_end() call, exactly as in a single loop.
    kept = (
        item
        for item in log_items
        if not (is_comm and item.comm_type == CommType.computation)
    )

    buckets = {}
    open_name = 'init'
    open_bucket = []
    buckets[open_name] = open_bucket
    next_index = 0

    for item in kept:
        closes_epoch = item.is_epoch_end()
        # The epoch-end marker belongs to the bucket it terminates.
        open_bucket.append(item)
        if closes_epoch:
            open_name = f'epoch_{next_index}'
            open_bucket = buckets[open_name] = []
            next_index += 1

    # Only the final, still-open bucket can be empty (every earlier one ends with
    # its epoch-end item), so that is the only one worth pruning.
    if not open_bucket:
        del buckets[open_name]
    return buckets