in tb_plugin/torch_tb_profiler/profiler/memory_parser.py [0:0]
def update_node(self, records_by_tid):
    """Attach each memory record to the tree node whose time span contains it.

    For every thread id, the records of that thread and the tree found in
    ``self.tid2tree`` are walked together. The walk resumes from the node
    visited for the previous record instead of restarting at the root, which
    relies on both the records and each node's children being ordered by
    time (assumption carried over from the original implementation; it is
    not checked here).

    Per record, one of three things happens:

    1. The current node contains the record: descend to the deepest child
       whose ``[start_time, end_time)`` span contains ``record.ts``.
    2. The record is at or past the current node's ``end_time``: pop back to
       the parent and continue with the sibling after the one just left.
    3. No node is left (stack exhausted, or the thread has no tree): the
       record, and every remaining one for that thread, is stale.

    Args:
        records_by_tid: mapping of thread id to the list of memory records
            of that thread. Each record is read for ``ts``, ``tid`` and
            ``is_allocation``.

    Side effects:
        * ``self.processed_node[node]`` is incremented for the root of each
          thread that has one and for every child descended into
          (presumably a counter-like mapping -- confirm at its definition).
        * Records landing on an operator node (per ``is_operator_node``) are
          passed to ``node.add_memory_record`` and appended to
          ``self.processed_records``; allocation records also get
          ``op_name`` and, when the node has a parent on the stack,
          ``parent_op_name``.
        * All other records are appended to ``self.staled_records``.
    """
    max_depth = 0
    for tid, records in records_by_tid.items():
        if not records:
            continue

        # Each item is the (parent_node, child_index) pair being visited.
        node_stack = []
        record_index = 0
        current_node = self.tid2tree.get(tid)
        child_index = 0

        if current_node:
            self.processed_node[current_node] += 1

        # Each pass either consumes one record or moves one level up the tree.
        while record_index < len(records):
            record = records[record_index]

            if current_node is None:
                # Case 3: nothing can contain this record, mark it stale.
                # Logger arguments are passed lazily so the message is only
                # built when debug logging is enabled.
                logger.debug(
                    "could not find the node for tid %d, timestamp: %d, record index: %d, total records: %d",
                    record.tid, record.ts, record_index, len(records))
                self.staled_records.append(record)
                record_index += 1
                continue

            if record.ts < current_node.start_time:
                # Should only happen for the root node. Later nodes start even
                # later, so this record can never be attached.
                logger.debug("record timestamp %d is less that the start time of %s",
                             record.ts, current_node.name)
                self.staled_records.append(record)
                record_index += 1
                continue

            if record.ts >= current_node.end_time:
                # Case 2: go back to the parent and resume after the child we
                # just left. The record is NOT consumed; it is retried.
                if len(node_stack) > 0:
                    current_node, child_index = node_stack.pop()
                    child_index += 1
                else:
                    # Nothing left on the stack: no node can hold the record.
                    current_node = None
                continue

            # Case 1: current_node contains the record. Descend to the deepest
            # child that still contains it.
            while child_index < len(current_node.children):
                child = current_node.children[child_index]
                if record.ts < child.start_time:
                    # The record falls in the gap before this child. Keep
                    # child_index unchanged so the next record resumes here.
                    break
                elif record.ts >= child.end_time:
                    # This child ended before the record; try the next one.
                    child_index += 1
                else:
                    # This child contains the record: step into it.
                    self.processed_node[child] += 1
                    node_stack.append((current_node, child_index))
                    # Track depth at push time so the depth reached for the
                    # last record of a thread is counted as well.
                    if len(node_stack) > max_depth:
                        max_depth = len(node_stack)
                    current_node = child
                    child_index = 0

            # current_node is now the innermost node containing the record.
            if is_operator_node(current_node):
                current_node.add_memory_record(record)
                # NOTE: only allocation records can be associated with an op,
                # because deallocation happens at the end of a tensor's
                # lifetime, which is not deterministic.
                if record.is_allocation:
                    record.op_name = current_node.name
                    if len(node_stack) > 0:
                        record.parent_op_name = node_stack[-1][0].name
                self.processed_records.append(record)
            else:
                self.staled_records.append(record)

            # The record is processed.
            record_index += 1

    # Summary information.
    if len(self.staled_records) > 0:
        total_records = self.record_length(records_by_tid)
        if total_records > 0:
            logger.debug("%s memory records are skipped in total %s memory records and only %s get processed",
                         len(self.staled_records), total_records, len(self.processed_records))
    if max_depth > 0:
        logger.debug("max tree height is %s", max_depth)