def _process_distributed_profiles()

in tb_plugin/torch_tb_profiler/profiler/loader.py
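
Uses sys.maxsize and DistributedRunGenerator; loader.py is assumed to import these at module level (import sys; from .run_generator import DistributedRunGenerator).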


    def _process_distributed_profiles(self, profiles, span):
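        """Align communication kernel ranges across all workers and build a
        distributed run profile; returns None when any worker lacks
        communication data."""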
        has_communication = True
        comm_node_lists = []
        for data in profiles:
            logger.debug("Processing profile data")
            # Disable the distributed view (has_communication = False) if any worker has no communication ops
            if data.has_communication and data.comm_node_list:
                comm_node_lists.append(data.comm_node_list)
                if len(comm_node_lists[-1]) != len(comm_node_lists[0]):
                    logger.error("Number of communication operation nodes don't match between workers in run: %s" % self.run_name)
                    has_communication = False
            else:
                has_communication = False
            logger.debug("Processing profile data finish")

        if not has_communication:
            logger.debug("There is no communication profile in this run.")
            return None

        worker_num = len(comm_node_lists)
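        # Iterate over the first worker's node list; node order is assumed to
        # be consistent across workers (the length checks above and below
        # catch mismatches).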
        for i, node in enumerate(comm_node_lists[0]):
            kernel_range_size = len(node.kernel_ranges)
            # Loop over all communication kernel ranges in order
            for j in range(kernel_range_size):
                min_range = sys.maxsize
                # For each kernel range, find the minimum duration across workers as the real communication time
                for k in range(worker_num):
                    kernel_ranges = comm_node_lists[k][i].kernel_ranges
                    if len(kernel_ranges) != kernel_range_size:
                        logger.error("Number of communication kernels don't match between workers in run: %s" % self.run_name)
                        has_communication = False
                        return None
                    if kernel_ranges:
                        duration = kernel_ranges[j][1] - kernel_ranges[j][0]
                        if duration < min_range:
                            min_range = duration
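                # Attribute only the minimum duration to real communication;
                # any excess on a worker is treated as wait time, so each
                # real-time range is anchored at that worker's kernel end.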
                for k in range(worker_num):
                    kernel_range = comm_node_lists[k][i].kernel_ranges[j]
                    comm_node_lists[k][i].real_time_ranges.append((kernel_range[1] - min_range, kernel_range[1]))

        for data in profiles:
            data.communication_parse()

        generator = DistributedRunGenerator(profiles, span)
        profile = generator.generate_run_profile()
        return profile
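
Below is a minimal, self-contained sketch of the alignment step above, using
plain (start, end) tuples instead of the plugin's communication node objects
(worker_kernel_ranges and real_time_ranges are hypothetical names for this
illustration): for each kernel slot, the minimum duration across workers is
taken as the real communication time and anchored at each worker's kernel end.

    # Toy data: per worker, the (start, end) kernel ranges of one
    # communication node. Values are illustrative only.
    worker_kernel_ranges = [
        [(0, 10), (20, 26)],   # worker 0
        [(2, 12), (21, 30)],   # worker 1
    ]

    real_time_ranges = [[] for _ in worker_kernel_ranges]
    num_kernels = len(worker_kernel_ranges[0])

    for j in range(num_kernels):
        # The shortest kernel across workers approximates the true transfer
        # time; anything longer on another worker is waiting caused by skew.
        min_range = min(end - start
                        for start, end in (r[j] for r in worker_kernel_ranges))
        for k, ranges in enumerate(worker_kernel_ranges):
            start, end = ranges[j]
            real_time_ranges[k].append((end - min_range, end))

    print(real_time_ranges)
    # [[(0, 10), (20, 26)], [(2, 12), (24, 30)]]

In the second kernel slot, worker 1's kernel (21, 30) runs 9 units while
worker 0's runs only 6, so 3 units of worker 1's range are treated as wait
time and its real communication range becomes (24, 30).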