def get_memory_statistics()

in tb_plugin/torch_tb_profiler/profiler/memory_parser.py


    def get_memory_statistics(self, start_ts=None, end_ts=None):
        metric_length = len(MemoryMetrics)
        self_metric_length = metric_length // 2
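        # assumption about the MemoryMetrics layout (the enum itself is defined elsewhere in
        # this module): the first half holds the Self* metrics (e.g. SelfIncreaseSize) and the
        # second half their subtree totals, so i + self_metric_length maps a self metric to
        # its aggregated counterpart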

        def dict_factory():
            return defaultdict(lambda: [0] * metric_length)

        # operator nodes collected during the traversal
        op_list = []
        # two-level dictionary: first keyed by node, then keyed by device (CPU/GPU0/GPU1/etc.)
        memory_metrics_keyed_by_node = defaultdict(dict_factory)

        def traverse_node_memory(node):
            if start_ts is not None and node.end_time < start_ts:
                return
            if end_ts is not None and node.start_time > end_ts:
                return

            is_op = is_operator_node(node)
            if is_op:
                op_list.append(node)

            if node not in self.processed_node:
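                # tid is bound by the enclosing "for tid, root in self.tid2tree.items()" loop below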
                self.unreached_node[tid].append(node)
                # the node has not been visited when inserting memory records, so ignore all of its children
                return
            elif is_op:
                node_memory_metrics = node.get_memory_metrics(start_ts, end_ts)
                for device, metrics in node_memory_metrics.items():
                    # device is the device name, e.g. CPU/GPU0
                    # metrics is an array [SelfIncreaseSize, SelfAllocationSize, SelfAllocationCount]
                    for i, value in enumerate(metrics):
                        memory_metrics_keyed_by_node[node][device][i] = value
                        memory_metrics_keyed_by_node[node][device][i + self_metric_length] += value
            else:
                logger.debug("node {}:{} is not operator node, will skip its self metrics processing".format(
                    node.name, node.start_time))

            # recurse into the child nodes
            for child in node.children:
                traverse_node_memory(child)
                # sum up the child metrics
                for device, metrics in memory_metrics_keyed_by_node[child].items():
                    for i in range(self_metric_length, metric_length):
                        memory_metrics_keyed_by_node[node][device][i] += metrics[i]

        for tid, root in self.tid2tree.items():
            for child in root.children:
                traverse_node_memory(child)

        # keyed first by device name (CPU/GPU0 etc.), then by operator name.
        # the value is an array indexed by MemoryMetrics
        memory_metrics_keyed_by_nodename = defaultdict(dict_factory)
        # node: the node instance; device_keyed_metrics: dictionary keyed by device name like CPU/GPU0
        for node, device_keyed_metrics in memory_metrics_keyed_by_node.items():
            if not is_operator_node(node):
                # skip non-operator nodes like Optimizer.step, DataLoader, ProfilerStep#1 etc.
                continue

            for device, metrics in device_keyed_metrics.items():
                for i, metric in enumerate(metrics):
                    memory_metrics_keyed_by_nodename[device][node.name][i] += metric

        # count the calls of each operator by aggregating the collected operator nodes
        op_calls = defaultdict(int)
        agg_result = aggregate_ops(op_list, [lambda op: op.name])
        for op_name, op_agg in agg_result[0].items():
            op_calls[op_name] += op_agg.calls

        result = defaultdict(defaultdict)
        for device, node_metrics in memory_metrics_keyed_by_nodename.items():
            for node, values in node_metrics.items():
                if any(values):
                    result[device][node] = values + [op_calls[node]]

        return result
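
A minimal usage sketch, assuming `parser` is an already-populated MemoryParser instance
(the variable name and the printed layout are illustrative, not part of the source): the
returned structure is keyed by device name, then by operator name, and each value is the
MemoryMetrics array with the operator call count appended.

    # hypothetical usage; `parser` is assumed to be a populated MemoryParser
    statistics = parser.get_memory_statistics()

    # layout: {device_name: {operator_name: [memory metrics..., call_count]}}
    for device, op_metrics in statistics.items():
        for op_name, values in op_metrics.items():
            *metrics, calls = values
            print(device, op_name, calls, metrics)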