in bench/generation/gen_barchart.py [0:0]
def gen_barchart(model_id, title, label, results, dtype):
    """Plot a grouped bar chart comparing weight/activation quantization combinations.

    One bar group is drawn per activation format: the unquantized base dtype
    and "f8". The first series is the base-dtype baseline (weights and
    activations both in the base dtype). It is repeated once per activation
    group so it sits alongside the quantized series. The remaining series are
    one per weight format ("i4", "i8", "f8"). All values are rounded to 2
    decimals before plotting.

    Args:
        model_id: Model identifier; every "/" becomes "-" in the output filename.
        title: Chart title forwarded to ``save_bar_chart``.
        label: Metric label. Used verbatim as the y-axis label, and with
            " ", "(" and ")" replaced by "_" in the output filename.
        results: Mapping looked up with keys of the form ``"W{weights}A{activations}"``
            (e.g. ``"Wi4Af8"``, ``"Wf16Af16"``) whose values are passed to ``round``.
        dtype: ``torch.float16`` selects the "f16" baseline; any other value
            selects "bf16".

    Raises:
        KeyError: if ``results`` lacks a required ``"W..A.."`` entry. The
            baseline key is looked up first, then each weight format in order.
    """
    base = "f16" if dtype is torch.float16 else "bf16"
    act_kinds = (base, "f8")
    weight_kinds = ("i4", "i8", "f8")

    def metric(w, a):
        # Rounded to 2 decimals so bar annotations stay readable.
        return round(results[f"W{w}A{a}"], 2)

    # Baseline first: insertion order of the dict determines series order in the chart.
    baseline = metric(base, base)
    series = {f"Weights {base}": [baseline for _ in act_kinds]}
    series.update(
        (f"Weights {w}", [metric(w, a) for a in act_kinds]) for w in weight_kinds
    )

    # Build filesystem-friendly name fragments for the output path.
    safe_model = "-".join(model_id.split("/"))
    safe_metric = label.translate(str.maketrans(" ()", "___"))
    group_labels = [f"Activations {a}" for a in act_kinds]

    save_bar_chart(
        title=title,
        labels=group_labels,
        series=series,
        ylabel=label,
        save_path=f"{safe_model}_{base}_{safe_metric}.png",
    )