in tb_plugin/torch_tb_profiler/profiler/event_parser.py [0:0]
def _parse_node(self, event, corrid_to_device, corrid_to_runtime, externalid_to_runtime, tid2list, tid2zero_rt_list):
    """Convert one trace event into a node and register it in the lookup tables.

    The event is dispatched on ``event.type``:

    * KERNEL / MEMCPY / MEMSET -> a ``DeviceNode``, linked to the runtime
      node with the same correlation id if that runtime was already seen,
      otherwise parked in ``corrid_to_device`` until it shows up.
    * RUNTIME -> a ``RuntimeNode`` that takes over any device nodes parked
      under its correlation id.
    * PYTHON / OPERATOR / PROFILER_STEP -> an ``OperatorNode`` (or a
      ``ProfilerStepNode``), with extra bookkeeping for communication ops
      and for DataParallel / DistributedDataParallel detection.

    Events of any other type are ignored. Nothing is returned; the result
    is entirely in the mutated arguments and ``self`` attributes.

    Args:
        event: Trace event; ``correlation_id``, ``tid`` and ``type`` are
            always read, other attributes depend on the branch taken.
        corrid_to_device: correlation id -> list of device nodes waiting
            for their runtime. Indexed without a membership check, so
            presumably a ``defaultdict(list)`` -- confirm with the caller.
        corrid_to_runtime: correlation id -> runtime node.
        externalid_to_runtime: external id -> list of runtime nodes.
            Also indexed without a check; presumably ``defaultdict(list)``.
        tid2list: ``int(tid)`` -> list of operator nodes.
        tid2zero_rt_list: ``tid`` -> list of runtime nodes whose
            ``external_id`` is 0.
    """
    correlation_id = event.correlation_id
    thread_id = event.tid
    kind = event.type

    # ---- device-side activity -------------------------------------------
    if kind in (EventTypes.KERNEL, EventTypes.MEMCPY, EventTypes.MEMSET):
        self.used_devices.add(event.pid)
        device_op = DeviceNode.create(event)
        if correlation_id not in corrid_to_runtime:
            # The launching runtime has not been parsed yet; park the node.
            corrid_to_device[correlation_id].append(device_op)
        else:
            # Look up without popping: a later kernel may share this runtime.
            launcher = corrid_to_runtime[correlation_id]
            if launcher.device_nodes is None:
                launcher.device_nodes = []
            launcher.device_nodes.append(device_op)

            # Both sides of a correlation are expected to share an external id.
            if launcher.external_id != device_op.external_id:
                logger.warning("Runtime and Device-op have same correlation id %s but with different external id! (runtime external_id, device external_id): (%s, %s)" %
                               (correlation_id, launcher.external_id, device_op.external_id))
        self.device_node_list.append(device_op)
        return

    # ---- runtime API call -----------------------------------------------
    if kind == EventTypes.RUNTIME:
        parked_ops = corrid_to_device.pop(correlation_id, None)
        runtime = RuntimeNode.create(event, parked_ops)
        corrid_to_runtime[correlation_id] = runtime
        externalid_to_runtime[runtime.external_id].append(runtime)
        # An external id of 0 cannot be correlated to any operator, so these
        # runtimes are collected per thread for attachment to the root node.
        # NOTE(review): keyed by the raw tid here, while tid2list below is
        # keyed by int(tid) -- confirm both agree for the caller.
        if runtime.external_id == 0:
            tid2zero_rt_list[thread_id].append(runtime)
        self.runtime_node_list.append(runtime)

        # Same external-id sanity check, for device nodes seen before us.
        for device_op in parked_ops or ():
            if runtime.external_id != device_op.external_id:
                logger.warning("Runtime and Device-op have same correlation id %s but with different external id! (rt external_id, device external_id): (%s, %s)" %
                               (correlation_id, runtime.external_id, device_op.external_id))
        return

    # ---- host-side operator / python / profiler step --------------------
    if kind not in (EventTypes.PYTHON, EventTypes.OPERATOR, EventTypes.PROFILER_STEP):
        return

    node_cls = ProfilerStepNode if kind == EventTypes.PROFILER_STEP else OperatorNode
    op_node = node_cls.create(event)

    op_name = event.name
    uses_nccl = op_name in NcclOpNameSet
    uses_gloo = op_name in GlooOpNameSet
    if uses_nccl or uses_gloo:
        comm_node = CommunicationNode.create(event)
        if uses_nccl:
            self.comm_lib.add(CommLibTypes.Nccl)
        if uses_gloo:
            self.comm_lib.add(CommLibTypes.Gloo)
        # Seed the communication node with the op's own time span.
        start = event.ts
        span = event.duration
        comm_node.kernel_ranges.append((start, start + span))
        comm_node.total_time = span
        self.communication_data[op_node.external_id] = comm_node

    # Remember which data-parallel wrapper (if any) appears in the trace.
    if op_name == "DataParallel.forward":
        self.use_dp = True
    if op_name == "DistributedDataParallel.forward":
        self.use_ddp = True

    tid2list[int(thread_id)].append(op_node)