def reportBenchTime()

in train/comms/pt/dlrm.py [0:0]


    # NOTE(review): "wamrupIters" is a typo for "warmupIters" in the signature;
    # kept as-is for call-site compatibility (keyword callers depend on it).
    def reportBenchTime(self, global_rank: int, wamrupIters: int, measuredIters: int, world_size: int, curDevice) -> None:
        """Gather per-region latency/memory samples from every rank and print percentile tables.

        Each rank packs its measured-region samples into two 2-D tensors (one for
        latency, one for memory), exchanges them across all ranks via all_gather,
        and then rank 0 alone computes and prints two percentile reports:
          1. percentiles over the pooled samples of all ranks
             (#samples = num_iterations * num_ranks), and
          2. percentiles over the per-rank mean latencies (#samples = num_ranks).

        Args:
            global_rank: this process's rank; only rank 0 prints the report.
            wamrupIters: number of leading warm-up samples to drop (applied to
                the memory samples below — see the NOTE at the latency line).
            measuredIters: number of measured iterations; 0 disables reporting.
            world_size: total number of ranks participating in the all_gather.
            curDevice: device on which the exchange tensors are allocated
                (passed straight to torch.tensor(..., device=curDevice)).

        Returns:
            None — output is emitted via print() on rank 0 only.
        """
        if(measuredIters != 0):
            # Fixed ordering of measured regions; the row order of the gathered
            # 2-D tensors below is defined by this list, so it must match on
            # every rank for the all_gather results to line up.
            all_timers = ['intermed_calc_length', 'mem_push_idx', 'intermed_bef_offset_xchg', 'offset_xchg', 'intermed_btw_offset_idx_xchg',
            'idx_xchg', 'intermed_post_idx_xchg_sparse_dist', 'intermed_emb_lookup_to_a2a_start', 'fwd_a2a', 'intermed_fwd_a2a_grad_push',
            'mem_push_gradients', 'bwd_top_ar', 'intermed_top_ar_end_to_bwd_a2a_start', 'bwd_a2a', 'intermed_bwd_a2a_bot_ar', 'bwd_bot_ar',
            'iter_time', 'iter_data_prep', 'iter_fwd_a2a', 'iter_bwd_top_ar', 'iter_bwd_a2a']

            # Each rank makes a list (2D tensor) of all the samples for each measured-region. Do the same for memory as well.
            # assumes self.measured_regions[region]['samples'] and ['memory'] are
            # equal-length numeric lists across all regions and ranks, so the
            # nested lists form rectangular tensors — TODO confirm with producer.
            combined_latency_list = []
            combined_memory_list = []
            for cur_region in all_timers:
                combined_latency_list.append(self.measured_regions[cur_region]['samples'])
                combined_memory_list.append(self.measured_regions[cur_region]['memory'])

            # All-gather to exchange the samples (memory and latency)
            timeElapsedTensor = torch.tensor(combined_latency_list, device=curDevice)
            # tensor_list entries are just pre-sized receive buffers; ones_like
            # only fixes shape/dtype — the values are overwritten by all_gather.
            tensor_list = [torch.ones_like(timeElapsedTensor) for _ in range(world_size)]
            self.collectiveArgs.ipTensor = timeElapsedTensor
            self.collectiveArgs.tensorList = tensor_list
            self.collectiveArgs.asyncOp = False
            self.collectiveArgs.dataSize = (
                timeElapsedTensor.nelement() * timeElapsedTensor.element_size()
            )
            self.collectiveArgs.numElements = timeElapsedTensor.nelement()
            self.backendFuncs.all_gather(self.collectiveArgs)
            self.backendFuncs.complete_accel_ops(self.collectiveArgs)

            # Second exchange for memory samples. Reuses the asyncOp/dataSize/
            # numElements fields set above (memory tensor is assumed to have the
            # same shape as the latency tensor — TODO confirm).
            memory_tensor = torch.tensor(combined_memory_list, device=curDevice)
            memory_tensor_list = [torch.ones_like(memory_tensor) for _ in range(world_size)]
            self.collectiveArgs.ipTensor = memory_tensor
            self.collectiveArgs.tensorList = memory_tensor_list
            self.backendFuncs.all_gather(self.collectiveArgs)
            self.backendFuncs.complete_accel_ops(self.collectiveArgs)

            # Running sums of p50 latencies over the non-"iter" regions; each
            # printed row embeds the partial sum accumulated so far.
            sum_latency = 0.0
            sum_mean_latency = 0.0
            if(global_rank == 0):
                # Pull every rank's gathered samples to the CPU for numpy work.
                cpu_tensor_latency = []
                cpu_tensor_memory = []
                for cur_region in range(world_size):
                    cpu_tensor_latency.append(tensor_list[cur_region].to('cpu'))
                    cpu_tensor_memory.append(memory_tensor_list[cur_region].to('cpu'))

                res_mean_percentiles = []
                res_percentiles = []
                print("\t{}\t{:>36}\t{:>12}\t{:>12}\t{:>12}\t{:>12}\t{:>12}\t{:>12}".format("iters","region","memory (B)","Latency(us):min","p50","p75","p95","sum(p50)"))
                for region_idx, cur_region in enumerate(all_timers):
                    # For each region, get data from different ranks. Compute percentiles for a given region.
                    all_rank_latency = []
                    all_rank_memory = []
                    all_rank_mean_latency = []
                    for cur_rank in range(world_size):
                        # NOTE(review): latency takes the full sample row, while
                        # memory is trimmed by [wamrupIters:] — asymmetric; either
                        # latency samples are pre-trimmed upstream or warm-up
                        # iterations leak into the latency percentiles. Verify.
                        cur_rank_latency = cpu_tensor_latency[cur_rank][region_idx]
                        cur_rank_memory = cpu_tensor_memory[cur_rank][region_idx][wamrupIters:]

                        all_rank_latency.append(cur_rank_latency)
                        all_rank_memory.append(cur_rank_memory)
                        all_rank_mean_latency.append(torch.mean(cur_rank_latency))

                    # Pool samples from all ranks into one 1-D tensor per region.
                    all_rank_latency = torch.cat(all_rank_latency)
                    all_rank_memory = torch.cat(all_rank_memory)

                    latencyAcrossRanks = np.array(all_rank_latency)
                    min_lat = torch.min(all_rank_latency)
                    p50 = np.percentile(latencyAcrossRanks, 50)
                    p75 = np.percentile(latencyAcrossRanks, 75)
                    p95 = np.percentile(latencyAcrossRanks, 95)

                    # Per-rank mean latencies (one scalar per rank).
                    mean_latencyAcrossRanks = np.array(all_rank_mean_latency)
                    mean_min_lat = min(all_rank_mean_latency)
                    mean_p50 = np.percentile(mean_latencyAcrossRanks, 50)
                    mean_p75 = np.percentile(mean_latencyAcrossRanks, 75)
                    mean_p95 = np.percentile(mean_latencyAcrossRanks, 95)

                    memoryAcrossRanks = np.array(all_rank_memory)
                    mem_p50 = np.percentile(memoryAcrossRanks, 50)
                    # Printing two sets of results --
                    # 1. Percentiles based on samples across all the ranks (so #samples = num_iterations * num_ranks)
                    # 2. Percentiles based on average latency at each rank (so #samples = num_ranks)

                    # "iter_*" regions are whole-iteration roll-ups; excluding
                    # them keeps the running sum from double-counting regions.
                    if('iter' not in cur_region):
                        sum_latency = sum_latency + p50
                        sum_mean_latency = sum_mean_latency + mean_p50
                    res_percentiles_line = "\t%d\t%36s\t%12s\t%12s\t%12s\t%12s\t%12s\t%12s" % (measuredIters, cur_region, '%d' % (mem_p50), '%.3f' % (min_lat),
                    '%.3f' % (p50), '%.3f' % (p75), '%.3f' % (p95), '%.3f' % (sum_latency))
                    res_mean_percentiles_line = "\t%d\t%36s\t%12s\t%12s\t%12s\t%12s\t%12s\t%12s" % (measuredIters, cur_region, '%d' % (mem_p50), '%.3f' % (mean_min_lat),
                    '%.3f' % (mean_p50), '%.3f' % (mean_p75), '%.3f' % (mean_p95), '%.3f' % (sum_mean_latency))

                    res_percentiles.append(res_percentiles_line)
                    res_mean_percentiles.append(res_mean_percentiles_line)

                # Report 1: pooled-sample percentiles; blank line before the
                # "iter_*" roll-up section for readability.
                for cur_line in res_percentiles:
                    if('iter_time' in cur_line):
                        print("\n")
                    print(cur_line)

                print("\t%d\t%36s\t%12s\t%12s\t%12s" % (measuredIters, "total_time", "N/A", "N/A", '%.3f' % (sum_latency)))
                print("\n\n -----------------------------------------------------------------------------------------------------------------------------\n\n")
                # Report 2: per-rank-mean percentiles, same layout.
                for cur_line in res_mean_percentiles:
                    if('iter_time' in cur_line):
                        print("\n")
                    print(cur_line)
                print("\t%d\t%36s\t%12s\t%12s\t%12s" % (measuredIters, "total_time", "N/A", "N/A", '%.3f' % (sum_mean_latency)))
                print("\n\n -----------------------------------------------------------------------------------------------------------------------------\n\n")