def gen_barchart()

in bench/generation/gen_barchart.py [0:0]


def gen_barchart(model_id, title, label, results, dtype):
    """Build a grouped bar chart comparing weight/activation dtype combinations.

    There is one bar group per activation dtype: the base dtype tag and
    ``"f8"``. There is one series per weight dtype. The first series is the
    base-dtype baseline ``W{base}A{base}``, repeated in every group so each
    other combination can be compared against it.

    Args:
        model_id: Model identifier. Every ``/`` is replaced by ``-`` when
            building the output file name.
        title: Chart title, forwarded to ``save_bar_chart``.
        label: Y-axis label. It is also used as the metric part of the file
            name, with spaces and parentheses replaced by ``_``.
        results: Mapping from keys such as ``"Wi4Af8"`` to values accepted by
            ``round``. It must contain ``W{base}A{base}`` and ``W{w}A{a}`` for
            every ``w`` in ``("i4", "i8", "f8")`` and every ``a`` in
            ``(base, "f8")``. A missing key raises ``KeyError``.
        dtype: ``torch.float16`` selects the ``"f16"`` tag. Any other value is
            tagged ``"bf16"``.

    The assembled series are passed to ``save_bar_chart`` with
    ``save_path="{model}_{base}_{metric}.png"``.
    """
    base = "f16" if dtype is torch.float16 else "bf16"
    activation_tags = (base, "f8")
    weight_tags = ("i4", "i8", "f8")

    # Baseline (base-dtype weights and activations), shown once per activation
    # group so every other combination sits next to it.
    baseline = round(results[f"W{base}A{base}"], 2)
    series = {f"Weights {base}": [baseline] * len(activation_tags)}
    for w in weight_tags:
        series[f"Weights {w}"] = [
            round(results[f"W{w}A{a}"], 2) for a in activation_tags
        ]

    # Make the model id and metric label safe to embed in a file name.
    safe_model = model_id.replace("/", "-")
    safe_metric = label.translate(str.maketrans(" ()", "___"))

    save_bar_chart(
        title=title,
        labels=[f"Activations {a}" for a in activation_tags],
        series=series,
        ylabel=label,
        save_path=f"{safe_model}_{base}_{safe_metric}.png",
    )