in train/comms/pt/comms.py [0:0]
def benchTime(self, index, commsParams, backendFuncs):
    # Get NW stack specific parameters
    (
        local_rank,
        global_rank,
        world_size,
        group,
        curDevice,
        curHwDevice,
        allSizes,
        computeFunc,
    ) = self.initCollectiveArgs(commsParams)
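
    # make sure every rank has finished initialization before rank 0 prints the preamble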
    backendFuncs.sync_barrier(self.collectiveArgs)
    if global_rank == 0:
        self.printPreamble(commsParams)
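
    # sweep over every message size; each iteration benchmarks one size end to end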
    for curSize in allSizes:
        results = {}
        timeUsElapsedList = []
        quantTimeElapsedList = []
        dequantTimeElapsedList = []
        numElements = int(curSize // commsParams.element_size)
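
        # default both function pointers to no-ops; they are replaced below when a collective is actually run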
        collectiveFunc = self.backendFuncs.noop
        collectiveFunc_pair = self.backendFuncs.noop

        if (
            commsParams.mode != "compute"
        ):  # comms specific initializations if not in compute-only mode
            # set corresponding function pointers
            if commsParams.collective != "pt2pt":
                collectiveFunc = backendFuncs.collectiveFunc[commsParams.collective]

            (
                self.collectiveArgs.ipTensor,
                self.collectiveArgs.opTensor,
            ) = self.prepComm(
                curComm={
                    "in_msg_size": numElements,
                    "out_msg_size": numElements,
                    "world_size": world_size,
                },
                commsParams=commsParams,
            )

        # Setup the arguments.
        self.collectiveArgs.dataSize = curSize
        self.collectiveArgs.numElements = numElements
        self.collectiveArgs.waitObj = []
        results["numElements"] = numElements

        if (
            commsParams.pair and commsParams.mode != "compute"
        ):  # comms-pair specific initializations if not in compute-only mode
            # set corresponding function pointers
            collectiveFunc_pair = backendFuncs.collectiveFunc[
                commsParams.collective_pair
            ]

            # TODO: allow user to set specific size
            # Setup the arguments.
            self.collectiveArgs.dataSize_pair = curSize
            self.collectiveArgs.numElements_pair = int(
                self.collectiveArgs.dataSize_pair // commsParams.element_size
            )
            results["numElements_pair"] = self.collectiveArgs.numElements_pair

            (
                self.collectiveArgs.ipTensor_pair,
                self.collectiveArgs.opTensor_pair,
            ) = self.prepComm(
                curComm={
                    "in_msg_size": self.collectiveArgs.numElements_pair,
                    "out_msg_size": self.collectiveArgs.numElements_pair,
                    "world_size": world_size,
                },
                commsParams=commsParams,
            )

        # self.collectiveArgs has all the information on the experiment.
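        # pt2pt uses a dedicated runner (ping, ping-pong, uni-/bi-directional BW); everything else goes through runColl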
        if commsParams.collective == "pt2pt":
            results.update(self.runPt2Pt())

            timeUsElapsedList = [
                np.mean(np.array(results["pingPerIterNS"])) / 1e3,
                np.mean(np.array(results["pingPongPerIterNS"])) / 1e3,
                results["avgUniBW"],
                results["avgBiBW"],
            ]  # time in US
            if (
                global_rank in self.collectiveArgs.src_ranks
                or global_rank in self.collectiveArgs.dst_ranks
            ):
                logger.debug(timeUsElapsedList)
        else:
            results.update(
                self.runColl(
                    comm_fn=collectiveFunc,
                    compute_fn=computeFunc,
                    comm_fn_pair=collectiveFunc_pair,
                )
            )
            timeUsElapsedList = [results["timeUS"]]

        # perform data validation check on the final opTensor
        if commsParams.dcheck == 1:
            self.dcheck(commsParams, curSize, self.collectiveArgs.opTensor)

        backendFuncs.clear_memory(self.collectiveArgs)

        # gather quantization overhead if enabled
        if commsParams.bitwidth < 32:
            # calculate average (de-)quantization overhead
            results["quantTimeUS"] = (
                self.collectiveArgs.quant_time.getTimeUS()
                / self.collectiveArgs.numIters
            )
            results["dequantTimeUS"] = (
                self.collectiveArgs.dequant_time.getTimeUS()
                / self.collectiveArgs.numIters
            )
            quantTimeElapsedList.append(results["quantTimeUS"])
            dequantTimeElapsedList.append(results["dequantTimeUS"])

            logger.debug(quantTimeElapsedList)
            quantTimeElapsedList = self.gatherBenchTime(
                self.collectiveArgs, commsParams, quantTimeElapsedList
            )
            dequantTimeElapsedList = self.gatherBenchTime(
                self.collectiveArgs, commsParams, dequantTimeElapsedList
            )

        # gather and report performance to stdout
        tensorList = self.gatherBenchTime(
            self.collectiveArgs, commsParams, timeUsElapsedList
        )
        if global_rank == 0:
            self.reportBenchTime(
                commsParams,
                results,
                tensorList,
                quantTimeElapsedList,
                dequantTimeElapsedList,
            )
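
        # keep all ranks in step before moving on to the next message size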
        self.backendFuncs.sync_barrier(
            self.collectiveArgs, desc=f"curSize_{curSize}"
        )

    comms_utils.clearQuantCommCtx(self.collectiveArgs)
    # wait for rank 0 to finish reporting results so other ranks do not garble the output
    self.backendFuncs.sync_barrier(self.collectiveArgs, "benchtime")
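
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of comms.py): the gather-and-report pattern
# above boils down to collecting one timing value from every rank so that
# rank 0 can summarize it. A minimal version, assuming torch.distributed is
# already initialized (e.g., with the gloo backend); the helper name
# gather_latencies is hypothetical, not an API of this benchmark.
import torch
import torch.distributed as dist


def gather_latencies(local_latency_us: float) -> list:
    """Collect one latency value (in microseconds) from every rank."""
    world_size = dist.get_world_size()
    local = torch.tensor([local_latency_us], dtype=torch.float32)
    gathered = [torch.zeros_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local)  # every rank ends up with all values
    return [t.item() for t in gathered]


# rank 0 can then compute min/max/percentiles over the returned list, which
# mirrors what gatherBenchTime/reportBenchTime do for the real measurements
# ---------------------------------------------------------------------------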