optimum_benchmark/profilers/ort_profiler.py (48 lines of code) (raw):
import json
import math
from logging import getLogger
from typing import List, Tuple

import pandas as pd
from optimum.onnxruntime import ORTModel
# Module-level logger used by the profiling helpers in this file.
LOGGER = getLogger(name="ort_profiler")
class ORTProfilingWrapper:
    """Callable wrapper around an ``ORTModel`` that can report per-kernel timings.

    Calls are forwarded unchanged to the wrapped model. ``get_profiling_records``
    ends profiling on the underlying session and parses the trace file it writes.
    NOTE(review): profiling is presumably enabled on the session elsewhere — confirm.
    """

    def __init__(self, module: ORTModel):
        self.module = module
        # Not populated anywhere in this class; presumably kept for external
        # callers that read it — verify before removing.
        self.profiling_records: List[Tuple[str, str, float]] = []

    def __call__(self, *args, **kwargs):
        """Forward the call to the wrapped model and return its output unchanged."""
        return self.module(*args, **kwargs)

    def get_profiling_records(self) -> List[Tuple[str, str, float]]:
        """End profiling and return timing records for the last recorded run.

        Returns:
            ``(name, op_name, duration)`` tuples produced by ``normalize_records``
            from the last occurrence of each event name in the trace. ``op_name``
            may be ``None`` when the event's ``args`` lack it.
        """
        # end_profiling() yields the path of the JSON trace file, which is read below.
        # NOTE(review): assumes `self.module.model` is the onnxruntime session — confirm.
        profiling_json = self.module.model.end_profiling()  # type: ignore
        # JSON is UTF-8 by specification; don't rely on the platform's locale encoding.
        with open(profiling_json, encoding="utf-8") as file_obj:
            profiling_data = json.load(file_obj)
        # The trace is either a bare list of events or an object holding them
        # under "traceEvents".
        if isinstance(profiling_data, dict):
            profiling_data = profiling_data["traceEvents"]
        profiling_records = extract_last_run_records(profiling_data)
        return normalize_records(profiling_records)
def _is_missing(value) -> bool:
    """Return True if *value* is ``None`` or a float NaN (pandas' marker for absent fields)."""
    return value is None or (isinstance(value, float) and math.isnan(value))


def normalize_records(data) -> List[Tuple[str, str, float]]:
    """Convert trace events into ``(name, op_name, seconds)`` tuples.

    An event is kept only if it has ``cat``, ``dur`` and ``args`` values (``None``
    and NaN both count as absent), and either its category is ``"Kernel"`` or its
    name ends with ``"kernel_time"``.

    Args:
        data: Iterable of event mappings; each kept event must have a ``name`` key.

    Returns:
        One tuple per kept event: the name with a trailing ``"_kernel_time"``
        removed, ``args["op_name"]`` (``None`` when absent), and ``dur / 1e6``.
    """
    suffix = "_kernel_time"
    records = []
    for item in data:
        cat = item.get("cat")
        dur = item.get("dur")
        arg = item.get("args")
        # Records coming through pandas carry NaN rather than None for absent
        # fields; treat both as missing (a NaN `args` would otherwise crash on .get).
        if _is_missing(cat) or _is_missing(dur) or _is_missing(arg):
            continue
        name = item["name"]
        if cat != "Kernel" and not name.endswith("kernel_time"):
            continue
        # dur is scaled by 1e6 and reported as seconds, i.e. treated as microseconds.
        seconds = dur / 1e6
        if cat in ("Kernel", "Node"):
            LOGGER.debug("Kernel/Node %s took %.2e seconds", name, seconds)
        # Strip only the trailing suffix; str.replace would also remove it mid-name.
        short_name = name[: -len(suffix)] if name.endswith(suffix) else name
        records.append((short_name, arg.get("op_name"), seconds))
    return records
def extract_last_run_records(data):
    """Keep only the last event recorded for each event name.

    Assuming *data* is in recording order, the last occurrence of a repeated event
    name is the one from the most recent run.

    Whole rows are kept. ``groupby("name").last()`` took the last non-null value of
    each column independently, which could splice fields from different events.

    Args:
        data: List of event mappings; keys other than ``name``, ``cat``, ``dur``
            and ``args`` are ignored.

    Returns:
        List of dicts with keys ``name``, ``cat``, ``dur`` and ``args``, sorted by
        name. Fields absent from the selected event are ``None``; events without a
        name are dropped.
    """
    columns = ["name", "cat", "dur", "args"]
    # `columns=` selects these keys and fills absent ones with NaN, so an empty
    # trace or an event lacking a key does not raise KeyError.
    frame = pd.DataFrame(data, columns=columns)
    frame = (
        frame.dropna(subset=["name"])  # groupby also discarded null keys
        .drop_duplicates(subset="name", keep="last")
        .sort_values("name")  # keep the name-sorted order groupby produced
    )
    # Consumers test for absent fields with `is None`; map pandas' NaN back to None.
    return [
        {
            key: None if isinstance(value, float) and math.isnan(value) else value
            for key, value in record.items()
        }
        for record in frame.to_dict(orient="records")
    ]