def reportBenchTimeColl()

in train/comms/pt/comms.py [0:0]


    def reportBenchTimeColl(self, commsParams, results, tensorList):
        """Print one ``COMMS-RES`` row with latency statistics and bandwidths.

        The gathered latencies are converted to a NumPy array, restricted to
        the ranks that take part in the collective, and summarized as
        p50/p75/p95/min/max. The row also carries ``results["memSize"]``,
        ``results["numElements"]``, ``results["algBW"]`` and
        ``results["busBW"]`` scaled by ``commsParams.bitwidth / 32.0``.

        Args:
            commsParams: Benchmark parameters. This method reads ``backend``,
                ``bitwidth`` and, in pair mode, ``collective_pair`` and
                ``comms_world_info.world_size``.
            results: Dict with the keys ``memSize``, ``numElements``,
                ``algBW`` and ``busBW`` (plus ``numElements_pair`` in pair
                mode).
            tensorList: Gathered latencies. For the ``"xla"`` backend it must
                support ``.view`` (presumably a single tensor -- confirm with
                the caller); otherwise it is either a list of tensors or
                anything ``np.array`` accepts.

        Side effects:
            Writes the row to stdout. In pair mode with an ``all_to_all`` /
            ``all_to_allv`` pair collective, ``results["numElements_pair"]``
            is overwritten in place with its value floor-divided by the world
            size.
        """
        # Bring the gathered latencies onto the host as a NumPy array.
        if commsParams.backend == "xla":
            flattened = torch.transpose(tensorList.view(-1, 1), 0, 1)[0]
            allRankLatency = flattened.cpu().detach().numpy()
        else:
            perRank = tensorList
            if isinstance(perRank, list):
                perRank = [entry.cpu().detach().numpy() for entry in perRank]
            allRankLatency = np.array(perRank)

        logger.debug(f"Latency across all ranks: {allRankLatency}")

        # Include only communicating ranks: multicast/incast involve the
        # source-or-destination rank plus an explicit peer list; every other
        # collective involves the whole world.
        collArgs = self.collectiveArgs
        peerListAttr = {"multicast": "dst_ranks", "incast": "src_ranks"}.get(
            collArgs.collective
        )
        if peerListAttr is None:
            commRanks = range(collArgs.world_size)
        else:
            commRanks = [collArgs.srcOrDst] + getattr(collArgs, peerListAttr)

        commRankLatency = allRankLatency[commRanks]
        logger.debug(
            "Latency across communicating ranks (%s): %s"
            % (commRanks, commRankLatency)
        )

        p50, p75, p95 = (np.percentile(commRankLatency, q) for q in (50, 75, 95))
        minlat = np.amin(commRankLatency)
        maxlat = np.amax(commRankLatency)

        # adjust busBW
        busBW = results["busBW"] * (commsParams.bitwidth / 32.0)

        if self.collectiveArgs.pair:
            # convert to # of elements per rank
            if commsParams.collective_pair in ("all_to_all", "all_to_allv"):
                perRankElements = (
                    results["numElements_pair"]
                    // commsParams.comms_world_info.world_size
                )
                results["numElements_pair"] = int(perRankElements)
            rowFormat = "\tCOMMS-RES{:>15}{:>18}{:>22}{:>18}{:>12}{:>12}{:>12}{:>12}{:>15}{:>12}"
            columns = [
                results["memSize"],
                "%d" % results["numElements"],
                "%d" % results["numElements_pair"],
            ]
        else:
            rowFormat = "\tCOMMS-RES{:>15}{:>18}{:>18}{:>12}{:>12}{:>12}{:>12}{:>15}{:>12}"
            columns = [
                results["memSize"],
                "%d" % results["numElements"],
            ]

        # Latency columns use one decimal place, bandwidth columns three.
        columns.extend("%.1f" % stat for stat in (p50, p75, p95, minlat, maxlat))
        columns.append("%.3f" % results["algBW"])
        columns.append("%.3f" % busBW)

        print(rowFormat.format(*columns))