def benchTime()

in train/comms/pt/commsTraceReplay.py [0:0]


    def benchTime(self, commsParams) -> None:
        """
        Replay the loaded comms trace and record the latency of every replayed op.

        Each trace entry is expected to be a JSON object in one of these forms:
        {
            "marker_stack": ["## all2all ##"],
            "comms": "all_to_allv",
            "in_msg_size": 10357149,
            "out_msg_size": 23093760,
            "in_split": [],
            "out_split": [],
            "dtype": "Int"
        },
        or w/o in/out_split
        {
            "marker_stack": ["## all2all ##"],
            "comms": "all_reduce",
            "in_msg_size": 1048576,
            "out_msg_size": 1048576,
            "dtype": "Int"
        }
        or wait/barrier
        {
            "marker_stack": ["## all2all ##"],
            "comms": "wait"
        }
        NOTE:
            - this format is subject to be changed anytime
            - the unit of all size fields is # of elements (not bytes)

        Steps performed:
            1. Optional warm-up (when ``self.do_warm_up`` is set), then a barrier.
            2. For each of the first ``self.max_msg_cnt`` trace entries whose
               comm name is in ``self.allowList``: prepare the tensors via
               ``prepComms``, run the op via ``runComms``, optionally validate
               the output via ``dcheck``, and record the measurements.
            3. A final barrier followed by ``clear_memory``.

        Side effects:
            - Each replayed trace entry is annotated in place with ``seqnum``,
              ``latency_us``, ``global_latency_us``, ``quant_us`` and
              ``dequant_us``, then appended to ``self.traceWithPerf`` and to
              ``self.comms_blocks[<marker>]`` for every marker in its stack.
            - ``self.collLat``, ``self.batchLat`` and ``self.totalCommsLatency``
              are updated.
            - When data validation runs, ``commsParams.collective`` and
              ``commsParams.srcOrDst`` are overwritten.

        Args:
            commsParams: benchmark parameters; forwarded to ``warmUpBench``,
                ``prepComms`` and ``dcheck``. Its ``dcheck`` attribute enables
                data validation when equal to 1.
        """
        # warm-up
        if self.do_warm_up:
            self.warmUpBench(commsParams)

        # sync everything before starting real runs
        self.backendFuncs.sync_barrier(self.collectiveArgs)

        # Query the rank once instead of twice per replayed message, so the
        # progress reporting adds no repeated lookups inside the timed loop.
        is_rank_zero = self.backendFuncs.get_global_rank() == 0

        if is_rank_zero:
            print(
                f"\n+ {self.max_msg_cnt} messages in the trace...replaying (if present) {list(self.allowList)}"
            )
            for coll, sizes in self.collInMsgSizes.items():
                logger.info(f"\t{coll}: {len(sizes)}")

        coll_in_batch_num = 0
        # Bound up front so the name is always defined; it is overwritten
        # below before any batch latency is computed.
        batch_begin = time.monotonic()
        for cnt, curComm in enumerate(self.comms_trace[: self.max_msg_cnt]):
            collName = paramToCommName(curComm["comms"])
            if collName not in self.allowList:
                continue

            # A missing, null or empty marker stack all mean "no markers".
            curBlocks = curComm.get("marker_stack") or []
            curBlockStack = " ".join(curBlocks) if curBlocks else "Unamed/Unknown"

            if is_rank_zero:
                logger.debug(
                    f"[Rank {self.collectiveArgs.global_rank:3}] Replaying \n{str(curComm)}\n"
                )
                print(f"[{cnt} / {self.max_msg_cnt}]", end="\r")

            # read fields and prepare the tensors
            (
                self.collectiveArgs.ipTensor,
                self.collectiveArgs.opTensor,
            ) = self.prepComms(curComm, commsParams)

            # NOTE(review): the start time is refreshed by every replayed op
            # while no "wait" of the current batch has completed yet, so the
            # measured window starts at the last such op -- confirm this is
            # the intended batch window.
            if self.colls_per_batch > 0 and coll_in_batch_num == 0:
                batch_begin = time.monotonic()

            (latency, global_latency) = self.runComms(collName, curBlockStack)

            # calculating batch latency (batch defined by --colls-per-batch);
            # only "wait" ops count towards the batch size
            if collName == "wait" and self.colls_per_batch > 0:
                coll_in_batch_num += 1
                if coll_in_batch_num == self.colls_per_batch:
                    batch_latency = (
                        time.monotonic() - batch_begin
                    ) * 1e3  # make it millisecond
                    coll_in_batch_num = 0
                    self.batchLat.append(batch_latency)

            # perform data validation check on the final opTensor
            if (
                self.is_blocking
                and commsParams.dcheck == 1
                and collName not in ("wait", "barrier")
            ):
                commsParams.collective = collName
                # entries without an explicit root default to rank 0
                commsParams.srcOrDst = curComm.get("root", 0)
                self.dcheck(
                    commsParams,
                    curComm["out_msg_size"],
                    self.collectiveArgs.opTensor,
                )

            self.collLat[collName].append(latency)

            curComm["seqnum"] = cnt
            curComm["latency_us"] = latency
            curComm["global_latency_us"] = global_latency
            curComm["quant_us"] = self.collectiveArgs.quant_time.getTimeUS()
            curComm["dequant_us"] = self.collectiveArgs.dequant_time.getTimeUS()
            self.totalCommsLatency += latency
            # Keep a copy of trace with performance (latency) and seqnum
            self.traceWithPerf.append(curComm)

            # categorized by the marker
            for curBlock in curBlocks:
                self.comms_blocks[curBlock].append(curComm)

            if is_rank_zero:
                logger.info(
                    f"[{cnt} / {self.max_msg_cnt}] Replayed {collName} in block [{curBlockStack}]... {global_latency:.2f} us"
                )

        # make sure all ops are completed
        self.backendFuncs.sync_barrier(self.collectiveArgs)
        self.backendFuncs.clear_memory(self.collectiveArgs)