in train/comms/pt/comms.py [0:0]
def reportBenchTimeColl(self, commsParams, results, tensorList):
    """Print one ``COMMS-RES`` row with latency percentiles and bandwidths.

    The per-rank latencies in ``tensorList`` are converted to a NumPy array,
    restricted to the ranks that take part in the collective, and summarised
    as p50 / p75 / p95 / min / max.  The row is written to stdout with
    ``print``; nothing is returned.

    Args:
        commsParams: Benchmark parameters.  Reads ``backend``, ``bitwidth``
            and, in pair mode, ``collective_pair`` and
            ``comms_world_info.world_size``.
        results: Dict providing ``memSize``, ``numElements``, ``algBW`` and
            ``busBW`` (plus ``numElements_pair`` in pair mode).
        tensorList: Gathered latencies.  For the "xla" backend it must be a
            tensor supporting ``view``; otherwise it is either a list of
            tensors or anything ``np.array`` accepts.

    Side effects:
        In pair mode with an ``all_to_all`` / ``all_to_allv`` pair
        collective, ``results["numElements_pair"]`` is overwritten in place
        with its per-rank value (floor-divided by the world size).
    """
    # Bring the gathered latencies onto the host as a NumPy array.
    if commsParams.backend == "xla":
        # view(-1, 1) -> transpose -> [0] yields the flattened 1-D tensor.
        latencyAcrossRanks = torch.transpose(tensorList.view(-1, 1), 0, 1)[0]
        latencyAcrossRanks = latencyAcrossRanks.cpu().detach().numpy()
    else:
        if isinstance(tensorList, list):
            tensorList = [t.cpu().detach().numpy() for t in tensorList]
        latencyAcrossRanks = np.array(tensorList)
    # Lazy %-args: the array is only stringified when DEBUG is enabled.
    logger.debug("Latency across all ranks: %s", latencyAcrossRanks)

    # Include only communicating ranks
    collectiveArgs = self.collectiveArgs
    if collectiveArgs.collective == "multicast":
        commRanks = [collectiveArgs.srcOrDst] + collectiveArgs.dst_ranks
    elif collectiveArgs.collective == "incast":
        commRanks = [collectiveArgs.srcOrDst] + collectiveArgs.src_ranks
    else:
        commRanks = range(collectiveArgs.world_size)
    latencyAcrossCommRanks = latencyAcrossRanks[commRanks]
    logger.debug(
        "Latency across communicating ranks (%s): %s",
        commRanks,
        latencyAcrossCommRanks,
    )

    p50, p75, p95 = np.percentile(latencyAcrossCommRanks, [50, 75, 95])
    minlat = np.amin(latencyAcrossCommRanks)
    maxlat = np.amax(latencyAcrossCommRanks)

    # adjust busBW: scale the reported value by bitwidth / 32
    # NOTE(review): presumably compensates for non-32-bit payloads - confirm.
    busBW = results["busBW"] * (commsParams.bitwidth / 32.0)

    # Each entry is (cell value, right-aligned column width).
    columns = [
        (results["memSize"], 15),
        ("%d" % results["numElements"], 18),
    ]
    if collectiveArgs.pair:
        # convert to # of elements per rank
        if commsParams.collective_pair in ("all_to_all", "all_to_allv"):
            results["numElements_pair"] = int(
                results["numElements_pair"]
                // commsParams.comms_world_info.world_size
            )
        # Pair mode adds one extra column for the second collective's size.
        columns.append(("%d" % results["numElements_pair"], 22))
    columns += [
        ("%.1f" % p50, 18),
        ("%.1f" % p75, 12),
        ("%.1f" % p95, 12),
        ("%.1f" % minlat, 12),
        ("%.1f" % maxlat, 12),
        ("%.3f" % results["algBW"], 15),
        ("%.3f" % busBW, 12),
    ]
    print(
        "\tCOMMS-RES"
        + "".join(format(value, ">%d" % width) for value, width in columns)
    )