in train/comms/pt/comms.py [0:0]
def reportBenchTimePt2Pt(self, commsParams, resultsAcrossRanks, results):
    """Print one summary row of point-to-point benchmark statistics.

    Every entry of ``resultsAcrossRanks`` is read at indices 0-3 through
    ``.item()``; the four values are treated, in that order, as ping
    latency, ping-pong latency, unidirectional bandwidth and
    bidirectional bandwidth. Statistics are computed only over the
    entries selected by ``self.collectiveArgs.src_ranks`` followed by
    ``self.collectiveArgs.dst_ranks``.

    Args:
        commsParams: Not used by this method.
        resultsAcrossRanks: Iterable of per-rank measurements; entry ``i``
            is selected when ``i`` appears in the communicating-rank list.
        results: Mapping from which only ``results["memSize"]`` is read
            (printed as the first column).

    Returns:
        None. The p50/p75/p95 latencies and the average/total bandwidths
        are written to stdout as a single ``COMMS-RES`` line.
    """
    # Pull the four scalars out of each rank's entry in a single pass.
    perRank = [
        (entry[0].item(), entry[1].item(), entry[2].item(), entry[3].item())
        for entry in resultsAcrossRanks
    ]
    pingLat, pingPongLat, uniBW, biBW = (
        np.array([row[col] for row in perRank]) for col in range(4)
    )

    # Include only communicating ranks
    activeRanks = self.collectiveArgs.src_ranks + self.collectiveArgs.dst_ranks
    pingLat = pingLat[activeRanks]
    pingPongLat = pingPongLat[activeRanks]
    uniBW = uniBW[activeRanks]
    biBW = biBW[activeRanks]

    debugTable = (
        ("Ping latency across communicating ranks (%s): %s", pingLat),
        ("PingPong latency across communicating ranks (%s): %s", pingPongLat),
        ("UniBW across all communicating ranks (%s): %s", uniBW),
        ("BiBW across all communicating ranks (%s): %s", biBW),
    )
    for template, values in debugTable:
        logger.debug(template % (activeRanks, values))

    # NOTE(review): totals are halved, presumably because each transfer is
    # reported by both its source and its destination rank - confirm.
    bwStats = [
        np.mean(uniBW),
        np.mean(biBW),
        np.sum(uniBW) / 2,
        np.sum(biBW) / 2,
    ]

    quantiles = (50, 75, 95)
    latencyStats = [np.percentile(pingLat, q) for q in quantiles]
    latencyStats += [np.percentile(pingPongLat, q) for q in quantiles]

    # Column order: memSize, ping p50/p75/p95, ping-pong p50/p75/p95,
    # avg uni BW, avg bi BW, total uni BW, total bi BW.
    cells = [results["memSize"]]
    cells.extend("%.1f" % value for value in latencyStats)
    cells.extend("%.3f" % value for value in bwStats)
    print(
        "\tCOMMS-RES{:>15}{:>20}{:>10}{:>10}{:>25}{:>10}{:>10}{:>15}{:>15}{:>18}{:>18}".format(
            *cells
        )
    )