in train/comms/pt/dlrm.py [0:0]
def benchTime(self, global_rank, world_size, timers, mConfig, curDevice, curDeviceData, args):
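# Per-layer sizes of the top/bottom layer tensors, used below when recording all-reduce message volumes.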
memSizes = self.getMemSizes(curDeviceData)
for batchNum, (_, lS_o, lS_i, _) in enumerate(mConfig.train_ld):
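# Per-batch flow: build the local sparse features, redistribute them across ranks, run the embedding lookup and the forward all-to-all, then issue all-reduces for the top and bottom layers, timing each region.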
timers['iter_start'] = time.monotonic()
curIterSparseFeatures = SparseFeatures(mConfig.num_sparse_fea,
self.expt_config['mini_batch_size'],
lS_o,
lS_i,
curDevice, global_rank,
self.backendFuncs,
self.collectiveArgs)
# Exchange indices and offsets among all ranks, so that each rank has them for only the tables it owns but for the entire global mini-batch, not just the local mini-batch.
g_offsets, g_indices = self.SparseDataDist(mConfig.n_emb_per_rank, curIterSparseFeatures, global_rank, world_size, timers)
# Begin the embedding-table lookup.
self.backendFuncs.complete_accel_ops(self.collectiveArgs, initOp=True)
timers['bef_emb_lookup'] = time.monotonic()
ly = self.paramNN.apply_emb(g_offsets, g_indices, curDeviceData['embedLayers'], mixed_dim=self.mixedDimFlag)
# Forward-pass all-to-all; this is a blocking communication.
# Placeholders for the forward all-to-all request and its outputs; populated in the non-mixed-dim path below.
a2a_req = None
B = None
tempB = None
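# Bytes produced by the embedding lookup; used when recording memory for the backward all-to-all region.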
emb_memory = ly.nelement() * ly.element_size()
if(not self.mixedDimFlag):
a2a_req = self.alltoallv(ly, global_rank, mConfig.dims_sum_per_rank, mConfig.n_emb_per_rank)
B = a2a_req.wait()
tempB = torch.cat(B, dim=1)
self.collectiveArgs.timers['grad_push_start'] = time.monotonic()
C = tempB
self.backendFuncs.complete_accel_ops(self.collectiveArgs, initOp=True) # ensure the queued ops have completed; otherwise the data is not actually moved
if(args.perf_debug):
self.backendFuncs.complete_accel_ops(self.collectiveArgs, initOp=True)
self.backendFuncs.barrier(self.collectiveArgs)
# Back-propagation: top-layer all-reduces, issued as non-blocking ops across the top layers.
timers['bwd_top_ar_start'] = time.monotonic()
for curLayerIdx in range(len(curDeviceData['topLayers'])):
# Prepare collective arguments
self.collectiveArgs.ipTensor = curDeviceData['topLayers'][curLayerIdx]
self.collectiveArgs.asyncOp = True
self.collectiveArgs.op = self.backendFuncs.get_reduce_op('sum')
self.backendFuncs.all_reduce(self.collectiveArgs)
# Record communication details for performance logging.
self.commDetails.append(
{
"comms" : "all_reduce",
"msg_size" : curDeviceData['topLayers'][curLayerIdx].nelement() * curDeviceData['topLayers'][curLayerIdx].element_size(),
"dtype" : str(curDeviceData['topLayers'][curLayerIdx].dtype),
}
)
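# Complete the queued (non-blocking) top-layer all-reduces before stopping the timer.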
self.backendFuncs.complete_accel_ops(self.collectiveArgs, initOp=True)
timers['bwd_top_ar_end'] = time.monotonic()
if(args.perf_debug):
self.backendFuncs.complete_accel_ops(self.collectiveArgs, initOp=True)
self.backendFuncs.barrier(self.collectiveArgs)
# Back-propagation: embedding update; blocking, since we wait for it to complete.
if(self.mixedDimFlag):
self.collectiveArgs.timers['bwd_a2a_start'] = time.monotonic()
tempB.backward(C)
self.collectiveArgs.timers['bwd_a2a_end'] = time.monotonic()
self.measured_regions['bwd_a2a']['memory'].append(emb_memory) # not an exact measurement in this case; it just ensures a non-zero entry for the ads-feeds model
else:
tempB.backward(C)
if(args.perf_debug):
self.backendFuncs.complete_accel_ops(self.collectiveArgs, initOp=True)
self.backendFuncs.barrier(self.collectiveArgs)
# Back-propagation: bottom-layer all-reduces, issued as non-blocking ops across the bottom layers.
timers['bwd_bot_ar_start'] = time.monotonic()
for curLayerIdx in range(len(curDeviceData['botLayers'])):
self.collectiveArgs.ipTensor = curDeviceData['botLayers'][curLayerIdx]
self.collectiveArgs.asyncOp = True
self.collectiveArgs.op = self.backendFuncs.get_reduce_op('sum')
self.backendFuncs.all_reduce(self.collectiveArgs)
self.commDetails.append(
{
"comms" : "all_reduce",
"msg_size" : curDeviceData['botLayers'][curLayerIdx].nelement() * curDeviceData['botLayers'][curLayerIdx].element_size(),
"dtype" : str(curDeviceData['botLayers'][curLayerIdx].dtype),
}
)
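# Complete the queued (non-blocking) bottom-layer all-reduces before recording the end time.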
self.backendFuncs.complete_accel_ops(self.collectiveArgs, initOp=True)
timers['bwd_bot_ar_end'] = time.monotonic()
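# Record the total message volume moved by the top- and bottom-layer all-reduces.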
self.measured_regions['bwd_top_ar']['memory'].append(sum(memSizes['top']))
self.measured_regions['bwd_bot_ar']['memory'].append(sum(memSizes['bot']))
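# Aggregate per-region timings and intermediate-region memory only after the warm-up batches.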
if(batchNum >= self.expt_config['warmup_batches']):
self.computeTimes(timers)
self.intermed_region_memory(timers)