def runColl()

in train/comms/pt/comms.py [0:0]


    def runColl(self, comm_fn=None, compute_fn=None, comm_fn_pair=None):
        """Benchmark a collective, optionally overlapped with compute and/or a second collective.

        Runs ``collectiveArgs.numWarmupIters`` warm-up iterations (not timed)
        followed by ``collectiveArgs.numIters`` timed iterations. Each iteration
        posts ``comm_fn``, then ``comm_fn_pair``, then ``compute_fn`` repeated
        ``collectiveArgs.numComputePerColl`` times.

        Timing:
            * Blocking mode (``collectiveArgs.asyncOp`` false on entry): a barrier
              precedes every iteration and outstanding ops are completed before
              the iteration's clock stops.
            * In both modes the final drain of outstanding ops after the loop is
              added to the measured time.

        Args:
            comm_fn: Collective to benchmark, called as ``comm_fn(collectiveArgs)``.
                ``None`` or ``backendFuncs.noop`` disables it. When disabled,
                tensor values are not reset between iterations.
            compute_fn: Optional compute kernel, called as
                ``compute_fn(collectiveArgs)``. ``None`` or ``backendFuncs.noop``
                disables it.
            comm_fn_pair: Optional second collective posted right after
                ``comm_fn``, called as ``comm_fn_pair(collectiveArgs, pair=True)``
                when enabled. While it is enabled, ``collectiveArgs.asyncOp`` is
                forced to True for the duration of this call so the two
                collectives can overlap. The caller's value is restored on exit.

        Returns:
            dict with keys ``"timeUS"`` (average per-iteration time in
            microseconds, from ``comms_utils.getAlgBW``), ``"algBW"``,
            ``"busBW"`` and ``"memSize"``. In pair mode the paired collective's
            memory size and bandwidths are added to these values.
        """
        args = self.collectiveArgs
        backend = self.backendFuncs

        backend.complete_accel_ops(args, initOp=True)
        backend.sync_barrier(args, desc="runColl_begin")

        elapsedTimeNS = 0.0
        # Sampled before pair mode may force asyncOp below, so the barrier/wait
        # discipline follows the mode the caller requested.
        is_blocking = not args.asyncOp
        enable_comms = not (comm_fn is None or comm_fn == backend.noop)
        enable_compute = not (compute_fn is None or compute_fn == backend.noop)
        enable_comms_pair = not (comm_fn_pair is None or comm_fn_pair == backend.noop)

        # For comms pair mode, force async comms for overlapping evaluation.
        # Remember the caller's setting and restore it on exit. Otherwise every
        # later runColl call would silently switch to non-blocking timing.
        orig_async_op = args.asyncOp
        if enable_comms_pair:
            args.asyncOp = True
        try:
            for nIter in range(args.numWarmupIters + args.numIters):
                if nIter == args.numWarmupIters:
                    # Flush non-blocking ops to ensure warmup is really complete
                    backend.complete_accel_ops(args)
                    ensureTensorFlush(args.opTensor)
                    if enable_comms_pair:
                        ensureTensorFlush(args.opTensor_pair)
                    # Start measuring time after warmup iterations
                    elapsedTimeNS = 0.0
                    args.quant_time.reset()
                    args.dequant_time.reset()
                # reset tensor values for data validation check
                if enable_comms:
                    self.setTensorVal(args.opTensor)
                # for blocking mode, do barrier before starting collective
                if is_blocking:
                    backend.sync_barrier(args)

                start = time.monotonic()
                # get_next_group() is called before each slot regardless of
                # which functions are supplied, so group selection does not
                # depend on comm_fn / comm_fn_pair being None.
                args.group = backend.get_next_group()
                if comm_fn is not None:
                    comm_fn(args)
                # post another collective if in comms pair mode (otherwise
                # comm_fn_pair is None or noop)
                args.group = backend.get_next_group()
                if comm_fn_pair is not None:
                    comm_fn_pair(args, pair=enable_comms_pair)

                if enable_compute:
                    for _ in range(args.numComputePerColl):
                        # TODO: investigate the cache effect
                        # Flush the cache
                        # _ = torch.rand(6 * 1024 * 1024 // 4).float() * 2  # V100 6MB L2 cache
                        compute_fn(args)
                if is_blocking:  # should be synchronous, wait for the collective
                    backend.complete_accel_ops(args)
                # keep time in ns; bandwidth is derived from nanosecond totals
                elapsedTimeNS += (time.monotonic() - start) * 1e9

            # Drain any outstanding (non-blocking) ops. This wait counts toward
            # the measured time so async mode includes completion.
            start = time.monotonic()
            backend.complete_accel_ops(args)
            end = time.monotonic()

            ensureTensorFlush(args.opTensor)
            if enable_comms_pair:
                ensureTensorFlush(args.opTensor_pair)

            elapsedTimeNS += (end - start) * 1e9

            memSize = backend.get_mem_size(args)

            avgIterNS, algBW = comms_utils.getAlgBW(
                elapsedTimeNS, memSize, args.numIters
            )
            busBW = backend.getBusBW(args.collective, algBW, args)
            if enable_comms_pair:
                memSize_pair = backend.get_mem_size(args, pair=enable_comms_pair)
                memSize += memSize_pair

                # The paired collective shares the same elapsed time window.
                _, algBW_pair = comms_utils.getAlgBW(
                    elapsedTimeNS, memSize_pair, args.numIters
                )
                algBW += algBW_pair

                busBW += backend.getBusBW(args.collective_pair, algBW_pair, args)

            backend.sync_barrier(args, desc="runColl_end")
        finally:
            args.asyncOp = orig_async_op

        return {
            "timeUS": avgIterNS / 1e3,
            "algBW": algBW,
            "busBW": busBW,
            "memSize": memSize,
        }