in train/comms/pt/commsTraceReplay.py [0:0]
def reportBenchTime(self, commsParams):
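"""Report statistics for the recorded trace: message counts and tensor-size
distributions, plus per-collective latency breakdowns when the trace was
actually replayed (skipped in dry-run mode)."""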
# TODO:
# 1) dry run: output some statistics, e.g., # of msgs, distribution of sizes (max, min, avg, p50, p95, etc.)
# 2) normal run: output 1) as well as perf. breakdown (e.g., a2a latencies at different phases, some percentages, etc.)
# Basic stats: total number of messages recorded in the trace.
print(
f"\n+++++ {len(self.comms_trace)} msgs recorded in {self.trace_file} +++++\n"
)
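# Per-block summary; latencies are only available on a real replay, so a
# dry run reports 0.00 us for each block.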
for curBlock, blockComms in self.comms_blocks.items():
lat_list = []
if not self.is_dry_run:
lat_list = [comm["latency_us"] for comm in blockComms]
Lats = np.array(lat_list)
logger.info(
f"+ {len(blockComms)} comms in block {curBlock}: {Lats.sum():.2f} us in total"
)
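# Per-collective message-size statistics, reported separately for the
# input and output tensors of each collective.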
logger.info("\n{} Message size Statistcs {}".format("=" * 20, "=" * 20))
for (name, collMsgs) in self.collInMsgSizes.items():
# input tensor
msgSizes = np.array(collMsgs)
print("-" * 50)
print(f"+ {len(msgSizes)} {name}")
print("-" * 50)
print(
f"Size of Input tensors (bytes)\n {'Total (MB)':>10} {'Max.':>15} {'Min.':>10} {'Average':>13} {'p50':>13} {'p95':>13}"
)
print(
"{:>10.2f} {:15.2f} {:10.2f} {:15.2f} {:15.2f} {:15.2f}".format(
msgSizes.sum() / 1024 / 1024,
msgSizes.max(),
msgSizes.min(),
np.average(msgSizes),
np.percentile(msgSizes, 50),
np.percentile(msgSizes, 95),
)
)
logger.debug(f" - Used sizes: {sorted(self.collInUniMsgSizes[name])}")
# output tensor
msgSizes = np.array(self.collOutMsgSizes[name])
print(
f"Size of Output tensors (bytes)\n {'Total (MB)':>10} {'Max.':>15} {'Min.':>10} {'Average':>13} {'p50':>13} {'p95':>13}"
)
print(
"{:>10.2f} {:15.2f} {:10.2f} {:15.2f} {:15.2f} {:15.2f}".format(
msgSizes.sum() / 1024 / 1024,
msgSizes.max(),
msgSizes.min(),
np.average(msgSizes),
np.percentile(msgSizes, 50),
np.percentile(msgSizes, 95),
)
)
logger.debug(f" - Used sizes: {sorted(self.collOutUniMsgSizes[name])}")
if not self.is_dry_run:
print("\n{} Performance of replayed comms {}".format("=" * 20, "=" * 20))
for (coll, lats) in self.collLat.items():
if len(lats) == 0:
continue
Lat = np.array(lats)
print(
"{}\n Replayed {} {} ({:.2f}%): \n{}".format(
"-" * 50,
len(lats),
coll,
(Lat.sum() / self.totalCommsLatency) * 100,
"-" * 50,
)
)
print(
f"Latency (us)\n {'Total':>10} {'Max.':>10} {'Min.':>10} {'Average':>10} {'p50':>10} {'p95':>10}"
)
print(
" {:10.2f} {:10.2f} {:10.2f} {:10.2f} {:10.2f} {:10.2f}".format(
Lat.sum(),
Lat.max(),
Lat.min(),
np.average(Lat),
np.percentile(Lat, 50),
np.percentile(Lat, 95),
)
)
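# Pair each latency with its input/output message sizes when sizes were
# recorded for this collective; otherwise log the latencies alone.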
msgSizeAndLatency = (
tuple(
zip(lats, self.collInMsgSizes[coll], self.collOutMsgSizes[coll])
)
if coll in self.collInMsgSizes
else lats
)
logger.debug(f"Latency and size of First ten: {msgSizeAndLatency[:10]}")
if self.colls_per_batch > 0:
print("\n{} Batch Latency Performance {}".format("=" * 20, "=" * 20))
BatchLat = np.array(self.batchLat)
print(
f"Batch Latency (ms)\n {'Total':>10} {'Max.':>10} {'Min.':>10} {'Average':>10} {'p50':>10} {'p95':>10}"
)
print(
" {:10.2f} {:10.2f} {:10.2f} {:10.2f} {:10.2f} {:10.2f}".format(
BatchLat.sum(),
BatchLat.max(),
BatchLat.min(),
np.average(BatchLat),
np.percentile(BatchLat, 50),
np.percentile(BatchLat, 95),
)
)