in tb_plugin/torch_tb_profiler/profiler/loader.py [0:0]
def _process_distributed_profiles(self, profiles, span):
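    """Build the distributed run profile for one span.

    Collects the communication node list from every worker, returns None if any
    worker has no communication data or the node counts disagree, then for each
    communication kernel takes the minimum duration across workers as the real
    communication time before generating the distributed view.
    """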
    has_communication = True
    comm_node_lists = []
    for data in profiles:
        logger.debug("Processing profile data")
        # Set has_communication to False and disable the distributed view if any worker has no communication
        if data.has_communication and data.comm_node_list:
            comm_node_lists.append(data.comm_node_list)
            if len(comm_node_lists[-1]) != len(comm_node_lists[0]):
                logger.error("Number of communication operation nodes doesn't match between workers in run: %s" % self.run_name)
                has_communication = False
        else:
            has_communication = False
        logger.debug("Processing profile data finished")

    if not has_communication:
        logger.debug("There is no communication profile in this run.")
        return None
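
    # For every communication node, compare the matching kernel range across all workers;
    # the shortest duration observed is taken as the real communication time.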
    worker_num = len(comm_node_lists)
    for i, node in enumerate(comm_node_lists[0]):
        kernel_range_size = len(node.kernel_ranges)
        # Loop over all communication kernel ranges in order
        for j in range(kernel_range_size):
            min_range = sys.maxsize
            # For each kernel range, take the minimum duration among workers as the real communication time
            for k in range(worker_num):
                kernel_ranges = comm_node_lists[k][i].kernel_ranges
                if len(kernel_ranges) != kernel_range_size:
                    logger.error("Number of communication kernels doesn't match between workers in run: %s" % self.run_name)
                    has_communication = False
                    return None
                if kernel_ranges:
                    if kernel_ranges[j][1] - kernel_ranges[j][0] < min_range:
                        min_range = kernel_ranges[j][1] - kernel_ranges[j][0]
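            # Record each worker's real communication time as the trailing min_range of its own
            # kernel range, e.g. with ranges (100, 180) and (120, 150) the minimum duration is 30,
            # so the recorded real time ranges become (150, 180) and (120, 150).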
            for k in range(worker_num):
                kernel_range = comm_node_lists[k][i].kernel_ranges[j]
                comm_node_lists[k][i].real_time_ranges.append((kernel_range[1] - min_range, kernel_range[1]))

    for data in profiles:
        data.communication_parse()

    generator = DistributedRunGenerator(profiles, span)
    profile = generator.generate_run_profile()
    return profile