def benchTime()

in train/comms/pt/comms.py [0:0]


    def benchTime(self, index, commsParams, backendFuncs):
        """Benchmark every configured message size and report the timings.

        For each size returned by ``initCollectiveArgs`` this method:

        1. Unless ``commsParams.mode == "compute"``, looks up the collective
           function and prepares input/output tensors via ``prepComm``. When
           ``commsParams.pair`` is set, it does the same for a second "pair"
           collective of the same size.
        2. Runs either ``runPt2Pt`` (``collective == "pt2pt"``) or ``runColl``.
        3. Optionally validates ``collectiveArgs.opTensor`` (``dcheck == 1``),
           then calls ``backendFuncs.clear_memory``.
        4. When ``bitwidth < 32``, computes the average per-iteration
           quantization/dequantization time and gathers it across ranks.
        5. Gathers the timings with ``gatherBenchTime``; rank 0 reports them
           with ``reportBenchTime``.

        A barrier follows every size. A final barrier, after the quantization
        context is cleared, makes ranks wait for rank 0's report.

        Args:
            index: Not used by this method; accepted for interface compatibility.
            commsParams: Benchmark configuration. Fields read here: ``mode``,
                ``collective``, ``collective_pair``, ``pair``, ``element_size``,
                ``dcheck`` and ``bitwidth``.
            backendFuncs: Backend object supplying ``collectiveFunc`` (indexed by
                collective name), ``noop``, ``sync_barrier`` and ``clear_memory``.
                It is used for every backend call in this method.
        """
        # Get NW stack specific parameters. Entries not used below are bound
        # to underscore-prefixed placeholders.
        (
            _local_rank,
            global_rank,
            world_size,
            _group,
            _curDevice,
            _curHwDevice,
            allSizes,
            computeFunc,
        ) = self.initCollectiveArgs(commsParams)

        backendFuncs.sync_barrier(self.collectiveArgs)
        if global_rank == 0:
            self.printPreamble(commsParams)

        for curSize in allSizes:
            results = {}
            quantTimeElapsedList = []
            dequantTimeElapsedList = []
            numElements = int(curSize // commsParams.element_size)
            # No-op defaults cover compute-only mode, pt2pt (which never
            # receives collectiveFunc), and runs without a paired collective.
            collectiveFunc = backendFuncs.noop
            collectiveFunc_pair = backendFuncs.noop

            # Comms-specific initialization is skipped in compute-only mode.
            if commsParams.mode != "compute":
                # pt2pt is driven by runPt2Pt and needs no collective pointer.
                if commsParams.collective != "pt2pt":
                    collectiveFunc = backendFuncs.collectiveFunc[commsParams.collective]

                (
                    self.collectiveArgs.ipTensor,
                    self.collectiveArgs.opTensor,
                ) = self.prepComm(
                    curComm={
                        "in_msg_size": numElements,
                        "out_msg_size": numElements,
                        "world_size": world_size,
                    },
                    commsParams=commsParams,
                )

            # Setup the arguments.
            self.collectiveArgs.dataSize = curSize
            self.collectiveArgs.numElements = numElements
            self.collectiveArgs.waitObj = []
            results["numElements"] = numElements

            # Comms-pair initialization, also skipped in compute-only mode.
            if commsParams.pair and commsParams.mode != "compute":
                collectiveFunc_pair = backendFuncs.collectiveFunc[
                    commsParams.collective_pair
                ]
                # TODO: allow user to set specific size; the pair currently
                # reuses curSize.
                self.collectiveArgs.dataSize_pair = curSize
                self.collectiveArgs.numElements_pair = int(
                    self.collectiveArgs.dataSize_pair // commsParams.element_size
                )
                results["numElements_pair"] = self.collectiveArgs.numElements_pair
                (
                    self.collectiveArgs.ipTensor_pair,
                    self.collectiveArgs.opTensor_pair,
                ) = self.prepComm(
                    curComm={
                        "in_msg_size": self.collectiveArgs.numElements_pair,
                        "out_msg_size": self.collectiveArgs.numElements_pair,
                        "world_size": world_size,
                    },
                    commsParams=commsParams,
                )

            # self.collectiveArgs now holds all the information on the experiment.
            if commsParams.collective == "pt2pt":
                results.update(self.runPt2Pt())
                # The *PerIterNS entries are presumably nanoseconds; dividing by
                # 1e3 yields the microsecond values expected downstream.
                # Bandwidth figures are passed through unchanged.
                timeUsElapsedList = [
                    np.mean(np.array(results["pingPerIterNS"])) / 1e3,
                    np.mean(np.array(results["pingPongPerIterNS"])) / 1e3,
                    results["avgUniBW"],
                    results["avgBiBW"],
                ]
                # Only ranks listed as pt2pt sources or destinations log values.
                if (
                    global_rank in self.collectiveArgs.src_ranks
                    or global_rank in self.collectiveArgs.dst_ranks
                ):
                    logger.debug(timeUsElapsedList)
            else:
                results.update(
                    self.runColl(
                        comm_fn=collectiveFunc,
                        compute_fn=computeFunc,
                        comm_fn_pair=collectiveFunc_pair,
                    )
                )
                timeUsElapsedList = [results["timeUS"]]

            # Perform data validation check on the final opTensor.
            # NOTE(review): in compute-only mode prepComm is not called for this
            # size, so opTensor is not refreshed here. Confirm dcheck is meant
            # to run in that mode.
            if commsParams.dcheck == 1:
                self.dcheck(commsParams, curSize, self.collectiveArgs.opTensor)

            backendFuncs.clear_memory(self.collectiveArgs)

            # Gather quantization overhead if enabled.
            if commsParams.bitwidth < 32:
                # Average (de-)quantization overhead per iteration.
                results["quantTimeUS"] = (
                    self.collectiveArgs.quant_time.getTimeUS()
                    / self.collectiveArgs.numIters
                )
                results["dequantTimeUS"] = (
                    self.collectiveArgs.dequant_time.getTimeUS()
                    / self.collectiveArgs.numIters
                )
                quantTimeElapsedList.append(results["quantTimeUS"])
                dequantTimeElapsedList.append(results["dequantTimeUS"])

                logger.debug(quantTimeElapsedList)
                quantTimeElapsedList = self.gatherBenchTime(
                    self.collectiveArgs, commsParams, quantTimeElapsedList
                )
                dequantTimeElapsedList = self.gatherBenchTime(
                    self.collectiveArgs, commsParams, dequantTimeElapsedList
                )

            # Gather and report performance to stdout.
            tensorList = self.gatherBenchTime(
                self.collectiveArgs, commsParams, timeUsElapsedList
            )
            if global_rank == 0:
                self.reportBenchTime(
                    commsParams,
                    results,
                    tensorList,
                    quantTimeElapsedList,
                    dequantTimeElapsedList,
                )

            backendFuncs.sync_barrier(
                self.collectiveArgs, desc=f"curSize_{curSize}"
            )

        comms_utils.clearQuantCommCtx(self.collectiveArgs)

        # Wait until rank 0 reports results so other ranks don't mess up the output.
        backendFuncs.sync_barrier(self.collectiveArgs, "benchtime")