in optimum/runs_base.py [0:0]
def __init__(self, run_config: dict):
    """Initialize the Run class holding methods to perform inference and evaluation given a config.

    A run compares a transformers model and an optimized model on latency/throughput, model size, and provided metrics.

    Args:
        run_config (dict): Parameters to use for the run. See [`~utils.runs.RunConfig`] for the expected keys.
    """
    # Validate the data (useful if used as standalone); the validated object itself is not kept.
    RunConfig(**run_config)

    self.task = run_config["task"]
    self.static_quantization = run_config["quantization_approach"] == "static"

    # Grid search over every (batch_size, input_length) combination, with a two-objective
    # study whose directions are ["maximize", "minimize"].
    search_space = {"batch_size": run_config["batch_sizes"], "input_length": run_config["input_lengths"]}
    self.study = optuna.create_study(
        directions=["maximize", "minimize"],
        sampler=optuna.samplers.GridSampler(search_space),
    )

    # cpu_info_command() is expected to return a complete shell command line, so it is
    # executed through the shell. It is passed as a plain string (not a one-element list),
    # which is the documented form for shell=True on every platform.
    cpu_info = subprocess.check_output(cpu_info_command(), shell=True).decode("utf-8")

    optimum_hash = None
    if "dev" in optimum_version.__version__:
        # NOTE: this records the HEAD of the remote repository at run time, which is not
        # necessarily the commit the installed dev build was made from.
        # Run git directly (no shell/awk pipeline): a pipeline's exit status is awk's, which
        # hid git failures and produced an empty hash. If git is unavailable or fails, the
        # hash is left as None, the same value used for non-dev versions.
        try:
            ls_remote_output = subprocess.check_output(
                ["git", "ls-remote", "https://github.com/huggingface/optimum.git", "HEAD"]
            )
        except (OSError, subprocess.CalledProcessError):
            ls_remote_output = b""
        # `git ls-remote` prints "<hash>\t<ref>"; keep only the first field.
        fields = ls_remote_output.decode("utf-8").split()
        optimum_hash = fields[0] if fields else None

    self.return_body = {
        "model_name_or_path": run_config["model_name_or_path"],
        "task": self.task,
        "task_args": run_config["task_args"],
        "dataset": run_config["dataset"],
        "quantization_approach": run_config["quantization_approach"],
        "operators_to_quantize": run_config["operators_to_quantize"],
        "node_exclusion": run_config["node_exclusion"],
        "aware_training": run_config["aware_training"],
        "per_channel": run_config["per_channel"],
        "calibration": run_config["calibration"],
        "framework": run_config["framework"],
        "framework_args": run_config["framework_args"],
        # Raw, unparsed text output of cpu_info_command().
        # TODO(review): consider parsing this into structured fields.
        "hardware": cpu_info,
        "versions": {
            "transformers": transformers.__version__,
            "optimum": optimum_version.__version__,
            "optimum_hash": optimum_hash,
        },
        # Filled in later by the evaluation/benchmark steps.
        "evaluation": {
            "time": [],
            "others": {"baseline": {}, "optimized": {}},
        },
        "max_eval_samples": run_config["max_eval_samples"],
        "time_benchmark_args": run_config["time_benchmark_args"],
    }