in workload_generator/workload_generator.py [0:0]
def __call__(self):
    """Generate the complete workload and return it.

    Replaces ``self.workload`` with a fresh ``Workload``, runs ``self.init()``,
    then executes ``args.epoch_num`` epochs. When pipeline model parallelism
    is enabled (``pipeline_model_parallel > 1``) and the frame is not
    ``"collective_test"``, an epoch is one ``with_pipeline_forward_backward()``
    call. Otherwise it is ``num_microbatches`` rounds of ``forward()`` /
    ``backward()``. Either way the epoch ends with a single ``step()`` call.

    An ``epoch_end`` ``LogItem`` is appended once right after ``init()`` and
    once more after every epoch.

    Returns:
        The populated ``self.workload``.
    """
    cfg = self.args

    def mark_epoch_end():
        # Re-read self.workload on every call, matching the inline form, in
        # case a generation step rebinds the attribute.
        self.workload.append(LogItem(comm_type=CommType.epoch_end))

    self.workload = Workload()
    self.init()
    mark_epoch_end()

    for _epoch in range(cfg.epoch_num):
        # Evaluated per epoch, as in the original loop.
        pipelined = (
            cfg.pipeline_model_parallel > 1 and cfg.frame != "collective_test"
        )
        if pipelined:
            # NOTE(review): presumably schedules all microbatches internally
            # (num_microbatches is not consulted on this path) — confirm.
            self.with_pipeline_forward_backward()
        else:
            for _microbatch in range(cfg.num_microbatches):
                self.forward()
                self.backward()
        # Both paths close the epoch with exactly one step() call.
        self.step()
        mark_epoch_end()

    return self.workload