# optimum_benchmark/profilers/fx_profiler.py

import time
from logging import getLogger
from typing import Any, List, Tuple

import torch
from torch.fx import Interpreter
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node

LOGGER = getLogger("fx_profiler")


class FXProfilingWrapper(Interpreter):
    """FX Interpreter that records the execution time of every node it runs.

    Wraps a ``GraphModule`` and, for each node executed, appends a
    ``(node_name, node_op, runtime_in_seconds)`` record that can be
    retrieved afterwards with :meth:`get_profiling_records`.
    """

    def __init__(self, module: GraphModule):
        super().__init__(module)
        # One (node name, node op, runtime in seconds) tuple per executed node.
        self.profiling_records: List[Tuple[str, str, float]] = []

    def run(self, *args) -> Any:
        """Execute the wrapped GraphModule with positional arguments."""
        return super().run(*args)

    def run_node(self, node: Node) -> Any:
        """Execute a single node and record how long it took, in seconds.

        NOTE(review): assumes ``self.module`` carries a ``device`` attribute
        set by the caller — plain GraphModules do not define one; confirm
        it is attached upstream.
        """
        if self.module.device.type == "cuda":
            # CUDA kernels launch asynchronously: time with CUDA events and
            # synchronize before reading the elapsed time.
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record(stream=torch.cuda.current_stream())
            return_val = super().run_node(node)
            end.record(stream=torch.cuda.current_stream())
            torch.cuda.synchronize()
            node_runtime = start.elapsed_time(end) / 1e3  # milliseconds -> seconds
        else:
            start = time.perf_counter_ns()
            return_val = super().run_node(node)
            end = time.perf_counter_ns()
            node_runtime = (end - start) / 1e9  # nanoseconds -> seconds

        # Lazy %-style args: skips string formatting entirely when DEBUG is
        # disabled, which matters on this per-node hot path.
        LOGGER.debug("Node %s took %.2e seconds", node.name, node_runtime)
        self.profiling_records.append((node.name, node.op, node_runtime))
        return return_val

    def __call__(self, **kwargs) -> Any:
        # Relies on kwargs insertion order matching the graph's placeholder
        # order (guaranteed by the callers that build these kwargs).
        args = kwargs.values()
        return super().run(*args)

    def get_profiling_records(self) -> List[Tuple[str, str, float]]:
        """Return the accumulated (name, op, runtime_seconds) records."""
        return self.profiling_records