bench/generation/gen_barchart.py (60 lines of code) (raw):

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json

import matplotlib.pyplot as plt
import numpy as np
import torch


# Command-line spelling of the supported reference dtypes.
_DTYPES = {"f16": torch.float16, "bf16": torch.bfloat16}


def _infer_dtype(results):
    """Return the reference dtype implied by the keys of a results mapping.

    The reference (unquantized) entry is stored under ``W<dtype>A<dtype>``, so
    the presence of ``Wf16Af16`` or ``Wbf16Abf16`` identifies the dtype.

    Raises:
        ValueError: if neither reference key is present.
    """
    for dtype_str, dtype in _DTYPES.items():
        if f"W{dtype_str}A{dtype_str}" in results:
            return dtype
    raise ValueError(
        "Cannot infer the reference dtype: expected a 'Wf16Af16' or 'Wbf16Abf16' entry in the results."
    )


def save_bar_chart(title, labels, ylabel, series, save_path):
    """Draw a grouped bar chart and save it to ``save_path``.

    Args:
        title: the chart title.
        labels: one tick label per group of bars (x axis).
        ylabel: the label of the vertical axis.
        series: a mapping of legend name to a sequence of values, one value
            per entry in ``labels``. Each mapping entry adds one bar per group.
        save_path: where the figure is written (passed to ``Figure.savefig``).
    """
    x = np.arange(len(labels))  # the label locations
    width = 0.15  # the width of the bars

    fig, ax = plt.subplots(layout="constrained")
    fig.set_figwidth(10)
    try:
        max_value = 0
        for index, (attribute, measurement) in enumerate(series.items()):
            max_value = max(max_value, max(measurement))
            # Each series is shifted by one bar width inside its group.
            rects = ax.bar(x + width * index, measurement, width, label=attribute)
            ax.bar_label(rects, padding=5)

        # Add some text for labels, title and custom x-axis tick labels, etc.
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        # Put each tick under the middle of its group, whatever the number of series.
        group_center = width * max(len(series) - 1, 0) / 2
        ax.set_xticks(x + group_center, labels)
        ax.legend(loc="upper left", ncols=4)
        # Leave headroom above the tallest bar for the bar labels and the legend.
        ax.set_ylim(0, max_value * 1.2)
        fig.savefig(save_path)
    finally:
        # Release the figure: pyplot keeps every open figure alive otherwise,
        # and this function is called once per model.
        plt.close(fig)


def gen_barchart(model_id, title, label, results, dtype=None):
    """Generate the bar chart of one model's results and save it as a PNG.

    The chart has one group per activation type (the reference dtype and
    ``f8``) and one bar per weight type (the reference dtype, ``i4``, ``i8``
    and ``f8``). The reference value is repeated in every group as a baseline.

    The image is written to the current working directory as
    ``<model>_<dtype>_<metric>.png``, where ``/`` in ``model_id`` is replaced
    by ``-`` and spaces/parentheses in ``label`` are replaced by ``_``.

    Args:
        model_id: the model identifier, used to build the file name.
        title: the chart title.
        label: the vertical axis label, also used to build the file name.
        results: a mapping of ``W<weights>A<activations>`` keys to numeric
            values; values are rounded to two decimals before plotting.
        dtype: the reference dtype. ``torch.float16`` selects the ``f16``
            keys, any other value selects the ``bf16`` keys. When ``None``
            (the default) it is inferred from the keys present in ``results``.

    Raises:
        ValueError: if ``dtype`` is ``None`` and cannot be inferred.
        KeyError: if an expected ``W<weights>A<activations>`` entry is missing.
    """
    if dtype is None:
        dtype = _infer_dtype(results)
    dtype_str = "f16" if dtype is torch.float16 else "bf16"
    activations = (dtype_str, "f8")
    weights = ("i4", "i8", "f8")

    reference = round(results[f"W{dtype_str}A{dtype_str}"], 2)
    # The unquantized reference is the same whatever the activations.
    series = {f"Weights {dtype_str}": [reference] * len(activations)}
    for w in weights:
        series[f"Weights {w}"] = [round(results[f"W{w}A{a}"], 2) for a in activations]

    model_name = model_id.replace("/", "-")
    metric_name = label.translate(str.maketrans(" ()", "___"))
    save_bar_chart(
        title=title,
        labels=[f"Activations {a}" for a in activations],
        series=series,
        ylabel=label,
        save_path=f"{model_name}_{dtype_str}_{metric_name}.png",
    )


def main():
    """Generate one bar chart per model found in a benchmark result file.

    The benchmark file is a JSON object mapping each model id to its results
    (see ``gen_barchart`` for the expected keys).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("benchmark", type=str, help="A benchmark result file (.json).")
    parser.add_argument("--title", type=str, required=True, help="The graph title.")
    parser.add_argument("--label", type=str, required=True, help="The graph vertical label.")
    parser.add_argument(
        "--dtype",
        type=str,
        choices=list(_DTYPES),
        default=None,
        help="The reference dtype of the benchmark. Inferred from the results when omitted.",
    )
    args = parser.parse_args()

    with open(args.benchmark, encoding="utf-8") as f:
        benchmark = json.load(f)

    dtype = None if args.dtype is None else _DTYPES[args.dtype]
    for model_id, results in benchmark.items():
        gen_barchart(model_id, args.title, args.label, results, dtype)


if __name__ == "__main__":
    main()