in train/comms/pt/dlrm.py [0:0]
def reportBenchTime(self, global_rank, warmupIters, measuredIters, world_size, curDevice):
    if measuredIters != 0:
        all_timers = [
            'intermed_calc_length', 'mem_push_idx', 'intermed_bef_offset_xchg', 'offset_xchg',
            'intermed_btw_offset_idx_xchg', 'idx_xchg', 'intermed_post_idx_xchg_sparse_dist',
            'intermed_emb_lookup_to_a2a_start', 'fwd_a2a', 'intermed_fwd_a2a_grad_push',
            'mem_push_gradients', 'bwd_top_ar', 'intermed_top_ar_end_to_bwd_a2a_start',
            'bwd_a2a', 'intermed_bwd_a2a_bot_ar', 'bwd_bot_ar',
            'iter_time', 'iter_data_prep', 'iter_fwd_a2a', 'iter_bwd_top_ar', 'iter_bwd_a2a',
        ]
        # Each rank builds a 2D list (regions x samples) of latency samples for every
        # measured region, and a matching 2D list of memory sizes.
        combined_latency_list = []
        combined_memory_list = []
        for cur_region in all_timers:
            combined_latency_list.append(self.measured_regions[cur_region]['samples'])
            combined_memory_list.append(self.measured_regions[cur_region]['memory'])
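        # Assumed structure (inferred from the usage above; the definition lives outside
        # this excerpt): self.measured_regions[region] is a dict with two parallel lists,
        # 'samples' (latency measurements, in us) and 'memory' (bytes), so the combined_*
        # lists just built form [num_regions x num_samples] tables.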
        # All-gather to exchange the latency samples across ranks.
        timeElapsedTensor = torch.tensor(combined_latency_list, device=curDevice)
        tensor_list = [torch.ones_like(timeElapsedTensor) for _ in range(world_size)]
        self.collectiveArgs.ipTensor = timeElapsedTensor
        self.collectiveArgs.tensorList = tensor_list
        self.collectiveArgs.asyncOp = False
        self.collectiveArgs.dataSize = (
            timeElapsedTensor.nelement() * timeElapsedTensor.element_size()
        )
        self.collectiveArgs.numElements = timeElapsedTensor.nelement()
        self.backendFuncs.all_gather(self.collectiveArgs)
        self.backendFuncs.complete_accel_ops(self.collectiveArgs)

        # All-gather to exchange the memory samples across ranks.
        memory_tensor = torch.tensor(combined_memory_list, device=curDevice)
        memory_tensor_list = [torch.ones_like(memory_tensor) for _ in range(world_size)]
        self.collectiveArgs.ipTensor = memory_tensor
        self.collectiveArgs.tensorList = memory_tensor_list
        self.backendFuncs.all_gather(self.collectiveArgs)
        self.backendFuncs.complete_accel_ops(self.collectiveArgs)
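        # At this point every rank holds the gathered results: tensor_list and
        # memory_tensor_list each contain world_size tensors of shape
        # [num_regions, num_samples], one per rank. Only rank 0 aggregates and
        # prints them below.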
        sum_latency = 0.0
        sum_mean_latency = 0.0
        if global_rank == 0:
            cpu_tensor_latency = []
            cpu_tensor_memory = []
            for cur_rank in range(world_size):
                cpu_tensor_latency.append(tensor_list[cur_rank].to('cpu'))
                cpu_tensor_memory.append(memory_tensor_list[cur_rank].to('cpu'))

            res_mean_percentiles = []
            res_percentiles = []
            print(
                "\t{}\t{:>36}\t{:>12}\t{:>12}\t{:>12}\t{:>12}\t{:>12}\t{:>12}".format(
                    "iters", "region", "memory (B)", "Latency(us):min", "p50", "p75", "p95", "sum(p50)"
                )
            )
            for region_idx, cur_region in enumerate(all_timers):
                # For each region, collect that region's samples from every rank and
                # compute percentiles across them.
                all_rank_latency = []
                all_rank_memory = []
                all_rank_mean_latency = []
                for cur_rank in range(world_size):
                    cur_rank_latency = cpu_tensor_latency[cur_rank][region_idx]
                    cur_rank_memory = cpu_tensor_memory[cur_rank][region_idx][warmupIters:]
                    all_rank_latency.append(cur_rank_latency)
                    all_rank_memory.append(cur_rank_memory)
                    all_rank_mean_latency.append(torch.mean(cur_rank_latency))
                all_rank_latency = torch.cat(all_rank_latency)
                all_rank_memory = torch.cat(all_rank_memory)

                latencyAcrossRanks = np.array(all_rank_latency)
                min_lat = torch.min(all_rank_latency)
                p50 = np.percentile(latencyAcrossRanks, 50)
                p75 = np.percentile(latencyAcrossRanks, 75)
                p95 = np.percentile(latencyAcrossRanks, 95)

                mean_latencyAcrossRanks = np.array(all_rank_mean_latency)
                mean_min_lat = min(all_rank_mean_latency)
                mean_p50 = np.percentile(mean_latencyAcrossRanks, 50)
                mean_p75 = np.percentile(mean_latencyAcrossRanks, 75)
                mean_p95 = np.percentile(mean_latencyAcrossRanks, 95)

                memoryAcrossRanks = np.array(all_rank_memory)
                mem_p50 = np.percentile(memoryAcrossRanks, 50)
                # Two sets of results are printed:
                # 1. Percentiles over samples from all ranks (#samples = num_iterations * num_ranks).
                # 2. Percentiles over the per-rank mean latency (#samples = num_ranks).
                if 'iter' not in cur_region:
                    sum_latency = sum_latency + p50
                    sum_mean_latency = sum_mean_latency + mean_p50
                res_percentiles_line = "\t%d\t%36s\t%12s\t%12s\t%12s\t%12s\t%12s\t%12s" % (
                    measuredIters, cur_region, '%d' % mem_p50, '%.3f' % min_lat,
                    '%.3f' % p50, '%.3f' % p75, '%.3f' % p95, '%.3f' % sum_latency)
                res_mean_percentiles_line = "\t%d\t%36s\t%12s\t%12s\t%12s\t%12s\t%12s\t%12s" % (
                    measuredIters, cur_region, '%d' % mem_p50, '%.3f' % mean_min_lat,
                    '%.3f' % mean_p50, '%.3f' % mean_p75, '%.3f' % mean_p95, '%.3f' % sum_mean_latency)
                res_percentiles.append(res_percentiles_line)
                res_mean_percentiles.append(res_mean_percentiles_line)
            for cur_line in res_percentiles:
                if 'iter_time' in cur_line:
                    print("\n")
                print(cur_line)
            print("\t%d\t%36s\t%12s\t%12s\t%12s" % (measuredIters, "total_time", "N/A", "N/A", '%.3f' % sum_latency))
            print("\n\n -----------------------------------------------------------------------------------------------------------------------------\n\n")

            for cur_line in res_mean_percentiles:
                if 'iter_time' in cur_line:
                    print("\n")
                print(cur_line)
            print("\t%d\t%36s\t%12s\t%12s\t%12s" % (measuredIters, "total_time", "N/A", "N/A", '%.3f' % sum_mean_latency))
            print("\n\n -----------------------------------------------------------------------------------------------------------------------------\n\n")