in workload_applyer.py [0:0]
    def __init__(self, workload=None, args=None, filename=None) -> None:
        if workload is None or args is None:
            assert (
                filename is not None
            ), "you should pass either workload and args or a filename to init WorkloadApplyer"
            workload, args = WorkloadWriter.load_workload(filename)
        # NOTE: the default process group is expected to be initialized by the
        # caller, e.g. torch.distributed.init_process_group(backend=args.backend).
self.args = args
world_size = torch.distributed.get_world_size()
# args.rank = torch.distributed.get_rank()
        if args.world_size != world_size:
            print(
                f"WARNING: world_size was {args.world_size} when the workload was generated, "
                f"but the current world size is {world_size}"
            )
            args.world_size = world_size
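        # Map this rank onto a local GPU, round-robin over the visible CUDA devices.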
device_count = torch.cuda.device_count()
self.device = args.rank % device_count
torch.cuda.set_device(self.device)
self.device = torch.cuda.current_device()
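        # Build the DP/TP/PP/EP communication groups and the pipeline-parallel
        # global-rank info.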
self.comm_group_info, self.pp_global_rank_info = (
self._generate_dp_tp_pp_ep_groups()
)
self.workload = workload
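        # Dispatch table: each CommType in the workload maps to the handler that
        # replays it.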
self.comm_type_function = {
CommType.barrier: self._apply_barrier,
CommType.broadcast: self._apply_broadcast,
CommType.reduce: self._apply_reduce,
CommType.all_reduce: self._apply_all_reduce,
CommType.all_gather: self._apply_all_gather,
CommType.reduce_scatter: self._apply_reduce_scatter,
CommType.isend: self._apply_p2pcommunication,
CommType.irecv: self._apply_p2pcommunication,
CommType.all_gather_into_tensor: self._apply_all_gather,
CommType.reduce_scatter_tensor: self._apply_reduce_scatter,
CommType.computation: self._apply_computation,
CommType.all_to_all: self._apply_all_to_all,
CommType.epoch_end: bench_logger.end_epoch,
}
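        # Largest message in the workload, in elements; a tuple-valued msg_size is
        # treated as a pair of tensor shapes whose element counts are summed.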
        def cal_tuple_num(t):
            return math.prod(t[0]) + math.prod(t[1])
max_msg_size = max(
[
(
item.msg_size
if isinstance(item.msg_size, int)
else cal_tuple_num(item.msg_size)
)
for item in self.workload.workload
]
)
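        # Flags that control how CommType.computation items are replayed; AIOB-based
        # computation is only enabled for the Megatron frame.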
self.gemm_cache = {}
self.computation_aiob = False
if args.aiob_enable and args.frame == "Megatron":
self.computation_aiob = True
self.skip_computation = False
self.always_apply_gemm = False
self.gemm_iters = 1 if self.always_apply_gemm else 50
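        # Pre-allocate one bfloat16 buffer sized for the largest message so the
        # replayed collectives can reuse it.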
self.buffer = torch.empty(
(max_msg_size,), dtype=torch.bfloat16, device=self.device
)
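
    # Illustrative usage (a sketch, not part of this file): assuming the caller has
    # already initialized torch.distributed, the applyer can be built from a saved
    # trace (the path below is a placeholder):
    #
    #   applyer = WorkloadApplyer(filename="path/to/saved_workload")
    #   # ...then replay the recorded communication/computation items.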