in train/comms/pt/commsTraceReplay.py [0:0]
def benchTime(self, commsParams):
"""
The json format is expecting to be either
{
"marker_stack": ["## all2all ##"]
"comms": "all_to_allv",
"in_msg_size": 10357149,
"out_msg_size": 23093760,
"in_split": [],
"out_split": [],
"dtype": "Int"
},
or w/o in/out_split
{
"marker_stack": ["## all2all ##"]
"comms": "all_reduce",
"in_msg_size": 1048576,
"out_msg_size": 1048576,
"dtype": "Int"
}
or wait/barrier
{
"marker_stack": ["## all2all ##"]
"comms": "wait",
}
NOTE:
- this format is subject to be changed anytime
- the unit of all size fields is # of elements (not bytes)
"""
# warm-up
if self.do_warm_up:
self.warmUpBench(commsParams)
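# warm-up replays ops before timing so that communicator setup and allocator caches do not skew the measured latencies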
# sync everything before starting real runs
self.backendFuncs.sync_barrier(self.collectiveArgs)
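# the barrier keeps all ranks aligned, so the first replayed op's latency is not skewed by stragglers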
if self.backendFuncs.get_global_rank() == 0:
print(
f"\n+ {self.max_msg_cnt} messages in the trace... replaying (if present): {list(self.allowList)}"
)
for coll, sizes in self.collInMsgSizes.items():
logger.info(f"\t{coll}: {len(sizes)}")
coll_in_batch_num = 0
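# when batching is enabled, a batch is delimited by counting completed "wait" ops;
# e.g. with --colls-per-batch 2, the sequence [all_reduce, wait, all_to_allv, wait]
# is timed as one batch from its first op until the second "wait" completes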
for cnt, curComm in enumerate(self.comms_trace[: self.max_msg_cnt]):
collName = paramToCommName(curComm["comms"])
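# replay only collectives present in the allowlist; everything else is skipped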
if collName not in self.allowList:
continue
curBlocks = curComm["marker_stack"] if "marker_stack" in curComm else []
curBlockStack = (
" ".join(curBlocks) if len(curBlocks) > 0 else "Unamed/Unknown"
)
if self.backendFuncs.get_global_rank() == 0:
logger.debug(
f"[Rank {self.collectiveArgs.global_rank:3}] Replaying \n{str(curComm)}\n"
)
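# single-line progress indicator, overwritten in place via the carriage return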
print(f"[{cnt} / {self.max_msg_cnt}]", end="\r")
# read fields and prepare the tensors
(
self.collectiveArgs.ipTensor,
self.collectiveArgs.opTensor,
) = self.prepComms(curComm, commsParams)
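# with --colls-per-batch > 0, start the batch timer at the first collective of a new batch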
if self.colls_per_batch > 0 and coll_in_batch_num == 0:
batch_begin = time.monotonic()
(latency, global_latency) = self.runComms(collName, curBlockStack)
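# runComms returns this rank's latency and the global (cross-rank) latency, both in microseconds (recorded in the *_us fields below)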
# calculating batch latency (batch defined by --colls-per-batch)
if collName == "wait" and self.colls_per_batch > 0:
coll_in_batch_num += 1
if coll_in_batch_num == self.colls_per_batch:
batch_latency = (time.monotonic() - batch_begin) * 1e3  # convert to milliseconds
coll_in_batch_num = 0
self.batchLat.append(batch_latency)
# perform data validation on the final opTensor (wait/barrier move no data, so there is nothing to check)
if self.is_blocking and commsParams.dcheck == 1 and collName not in ("wait", "barrier"):
commsParams.collective = collName
commsParams.srcOrDst = curComm.get("root", 0)
self.dcheck(commsParams, curComm["out_msg_size"], self.collectiveArgs.opTensor)
self.collLat[collName].append(latency)
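# annotate the trace entry in place with its sequence number and measured latencies for later dumping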
curComm["seqnum"] = cnt
curComm["latency_us"] = latency
curComm["global_latency_us"] = global_latency
curComm["quant_us"] = self.collectiveArgs.quant_time.getTimeUS()
curComm["dequant_us"] = self.collectiveArgs.dequant_time.getTimeUS()
self.totalCommsLatency += latency
# Keep a copy of trace with performance (latency) and seqnum
self.traceWithPerf.append(curComm)
# categorize the replayed op by its marker block(s)
for curBlock in curBlocks:
self.comms_blocks[curBlock].append(curComm)
if self.backendFuncs.get_global_rank() == 0:
logger.info(
f"[{cnt} / {self.max_msg_cnt}] Replayed {collName} in block [{curBlockStack}]... {global_latency:.2f} us"
)
# make sure all ops are completed
self.backendFuncs.sync_barrier(self.collectiveArgs)
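# release tensors cached for the replay now that all ops have completed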
self.backendFuncs.clear_memory(self.collectiveArgs)