def __init__()

in workload_applyer.py [0:0]

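Constructs a WorkloadApplyer from an in-memory (workload, args) pair, or loads both from a file written by WorkloadWriter when only filename is given. It then binds the process to a CUDA device, builds the data/tensor/pipeline/expert parallel groups, prepares the CommType-to-handler dispatch table, and allocates a reusable bfloat16 buffer sized to the largest message. The method relies on module-level imports of math, torch, CommType, WorkloadWriter, and bench_logger, and on torch.distributed being initialized beforehand.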

    def __init__(self, workload=None, args=None, filename=None) -> None:
        if workload is None or args is None:
            # Fall back to loading a serialized (workload, args) pair from disk.
            assert (
                filename is not None
            ), "you must pass either (workload, args) or a filename to init WorkloadApplyer"
            workload, args = WorkloadWriter.load_workload(filename)
        # torch.distributed.init_process_group is expected to have been
        # called by the caller before this constructor runs.
        self.args = args
        world_size = torch.distributed.get_world_size()
        # args.rank = torch.distributed.get_rank()
        if args.world_size != world_size:
            print(
                f"WARNNING: world_size is {args.world_size} when generating workload, but now world size is {world_size}"
            )
            args.world_size = torch.distributed.get_world_size()
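        # Bind this rank to a local CUDA device.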
        device_count = torch.cuda.device_count()
        self.device = args.rank % device_count
        torch.cuda.set_device(self.device)
        self.device = torch.cuda.current_device()
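        # Build the data/tensor/pipeline/expert parallel communication groups
        # and record this rank's pipeline-parallel peer ranks.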
        self.comm_group_info, self.pp_global_rank_info = (
            self._generate_dp_tp_pp_ep_groups()
        )
        self.workload = workload
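        # Dispatch table: maps each CommType in the workload to the handler
        # that replays it.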
        self.comm_type_function = {
            CommType.barrier: self._apply_barrier,
            CommType.broadcast: self._apply_broadcast,
            CommType.reduce: self._apply_reduce,
            CommType.all_reduce: self._apply_all_reduce,
            CommType.all_gather: self._apply_all_gather,
            CommType.reduce_scatter: self._apply_reduce_scatter,
            CommType.isend: self._apply_p2pcommunication,
            CommType.irecv: self._apply_p2pcommunication,
            CommType.all_gather_into_tensor: self._apply_all_gather,
            CommType.reduce_scatter_tensor: self._apply_reduce_scatter,
            CommType.computation: self._apply_computation,
            CommType.all_to_all: self._apply_all_to_all,
            CommType.epoch_end: bench_logger.end_epoch,
        }

        # msg_size is either an int (element count) or a pair of shape
        # tuples, e.g. ((1024, 1024), (1024, 4096)) counts
        # 1024*1024 + 1024*4096 = 5_242_880 elements.
        cal_tuple_num = lambda t: math.prod(t[0]) + math.prod(t[1])
        max_msg_size = max(
            (
                item.msg_size
                if isinstance(item.msg_size, int)
                else cal_tuple_num(item.msg_size)
            )
            for item in self.workload.workload
        )
        self.gemm_cache = {}
        # Replay computation via AIOB only when it is enabled and the
        # workload's framework is Megatron.
        self.computation_aiob = bool(args.aiob_enable and args.frame == "Megatron")

        self.skip_computation = False
        self.always_apply_gemm = False
        self.gemm_iters = 1 if self.always_apply_gemm else 50
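        # Reusable communication buffer sized to the largest message in the
        # workload (bfloat16 elements on the local device).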
        self.buffer = torch.empty(
            (max_msg_size,), dtype=torch.bfloat16, device=self.device
        )
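
For context, a minimal driver sketch showing how this constructor is typically reached; the module path workload_applyer, the nccl backend choice, and the workload.pkl filename are assumptions for illustration, not part of this file:

    import torch.distributed as dist

    from workload_applyer import WorkloadApplyer

    def main() -> None:
        # __init__ assumes the default process group already exists.
        dist.init_process_group(backend="nccl")
        # Only filename is given, so workload and args are loaded from disk.
        applyer = WorkloadApplyer(filename="workload.pkl")
        # ... replay the workload with the applyer's public entry point ...
        dist.destroy_process_group()

    if __name__ == "__main__":
        main()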