in train/comms/pt/comms.py [0:0]
def runColl(self, comm_fn=None, compute_fn=None, comm_fn_pair=None):
    """Run the timed benchmark loop for a collective, optionally overlapped
    with compute kernels and/or a second "pair" collective.

    Args:
        comm_fn: callable issuing the main collective on ``collectiveArgs``;
            ``None`` or the backend ``noop`` disables the main collective.
        compute_fn: callable issuing compute work per collective;
            ``None``/``noop`` disables compute.
        comm_fn_pair: callable issuing a second collective to overlap with
            the main one; ``None``/``noop`` disables pair mode.

    Returns:
        dict with "timeUS" (average per-iteration latency in microseconds),
        "algBW" and "busBW" (bandwidths; summed across both collectives in
        pair mode), and "memSize" (bytes moved per iteration, summed in
        pair mode).
    """
    # Drain outstanding device work and line all ranks up before timing.
    self.backendFuncs.complete_accel_ops(self.collectiveArgs, initOp=True)
    self.backendFuncs.sync_barrier(self.collectiveArgs, desc="runColl_begin")

    elapsedTimeNS = 0.0
    is_blocking = not self.collectiveArgs.asyncOp
    enable_comms = comm_fn is not None and comm_fn != self.backendFuncs.noop
    enable_compute = (
        compute_fn is not None and compute_fn != self.backendFuncs.noop
    )
    enable_comms_pair = (
        comm_fn_pair is not None and comm_fn_pair != self.backendFuncs.noop
    )

    # for comms pair mode, force async comms for overlapping evaluation
    # NOTE(review): asyncOp is not restored afterwards — confirm callers
    # do not depend on the pre-call value surviving this method.
    if enable_comms_pair:
        self.collectiveArgs.asyncOp = True

    for nIter in range(
        self.collectiveArgs.numWarmupIters + self.collectiveArgs.numIters
    ):
        if nIter == self.collectiveArgs.numWarmupIters:
            # Flush non-blocking ops to ensure warmup is really complete
            self.backendFuncs.complete_accel_ops(self.collectiveArgs)
            ensureTensorFlush(self.collectiveArgs.opTensor)
            if enable_comms_pair:
                ensureTensorFlush(self.collectiveArgs.opTensor_pair)
            # Start measuring time after warmup iterations
            elapsedTimeNS = 0.0
            self.collectiveArgs.quant_time.reset()
            self.collectiveArgs.dequant_time.reset()
        # reset tensor values for data validation check
        if enable_comms:
            self.setTensorVal(self.collectiveArgs.opTensor)
        # for blocking mode, do barrier before starting collective
        if is_blocking:
            self.backendFuncs.sync_barrier(self.collectiveArgs)

        start = time.monotonic()  # available only in py3
        # BUGFIX: both fns were previously invoked unconditionally, raising
        # TypeError when left at their default of None. Guard on None only,
        # so an explicit backend noop still executes exactly as before.
        if comm_fn is not None:
            self.collectiveArgs.group = self.backendFuncs.get_next_group()
            comm_fn(self.collectiveArgs)
        # post another collective if in comms-pair mode, otherwise it's a noop
        if comm_fn_pair is not None:
            self.collectiveArgs.group = self.backendFuncs.get_next_group()
            comm_fn_pair(self.collectiveArgs, pair=enable_comms_pair)

        if enable_compute:
            for _ in range(self.collectiveArgs.numComputePerColl):
                # TODO: investigate the cache effect
                # Flush the cache
                # _ = torch.rand(6 * 1024 * 1024 // 4).float() * 2 # V100 6MB L2 cache
                compute_fn(self.collectiveArgs)
        if is_blocking:  # should be synchronous, wait for the collective
            self.backendFuncs.complete_accel_ops(self.collectiveArgs)
        # Keeping time in NS helps when dividing data size by nanoseconds.
        elapsedTimeNS += (time.monotonic() - start) * 1e9

    # Also charge the final drain of any still-outstanding (async) ops
    # to the measured time.
    start = time.monotonic()  # available only in py3
    self.backendFuncs.complete_accel_ops(self.collectiveArgs)
    end = time.monotonic()  # available only in py3

    ensureTensorFlush(self.collectiveArgs.opTensor)
    if enable_comms_pair:
        ensureTensorFlush(self.collectiveArgs.opTensor_pair)

    elapsedTimeNS += (end - start) * 1e9

    memSize = self.backendFuncs.get_mem_size(self.collectiveArgs)

    avgIterNS, algBW = comms_utils.getAlgBW(
        elapsedTimeNS, memSize, self.collectiveArgs.numIters
    )
    busBW = self.backendFuncs.getBusBW(
        self.collectiveArgs.collective,
        algBW,
        self.collectiveArgs,
    )
    # In pair mode, report the combined data volume and bandwidth of both
    # collectives (bus BW uses each collective's own scaling factor).
    if enable_comms_pair:
        memSize_pair = self.backendFuncs.get_mem_size(
            self.collectiveArgs, pair=enable_comms_pair
        )
        memSize += memSize_pair
        _, algBW_pair = comms_utils.getAlgBW(
            elapsedTimeNS, memSize_pair, self.collectiveArgs.numIters
        )
        algBW += algBW_pair
        busBW += self.backendFuncs.getBusBW(
            self.collectiveArgs.collective_pair,
            algBW_pair,
            self.collectiveArgs,
        )

    self.backendFuncs.sync_barrier(self.collectiveArgs, desc="runColl_end")

    results = {
        "timeUS": avgIterNS / 1e3,
        "algBW": algBW,
        "busBW": busBW,
        "memSize": memSize,
    }
    return results