benchmarking/switchback/make_plot_with_jsonl.py (126 lines of code) (raw):

"""Plot SwitchBack int8 vs. standard fp16 linear-layer timings from a JSONL log.

Left panel: time per operation versus embedding dimension at one batch size.
Right panel: % speedup of the SwitchBack int8 total over the standard fp16
total, for several batch sizes.

Every plotted value is the average over the two shapes
dim -> 4*dim and 4*dim -> dim for a given embedding dim.
"""

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import pandas as pd

cmap = plt.get_cmap("cool")

# "+"-separated column names; the timings of the listed columns are summed.
_STANDARD_TOTAL = "standard_gx+standard_gw+standard_fwd"
_SWITCHBACK_TOTAL = (
    "x_quantize_rowwise+g_quantize_rowwise+w_quantize_global"
    "+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd"
)


def _avg_time(df, embed_dim, key):
    """Return the summed time of the columns in *key*, averaged over both shapes.

    Args:
        df: benchmark rows, already filtered to a single batch size. Must have
            ``dim_in`` and ``dim_out`` columns plus every column named in *key*.
        embed_dim: the embedding dim; the shapes used are
            (embed_dim -> 4*embed_dim) and (4*embed_dim -> embed_dim).
        key: one column name, or several joined with "+".

    Returns:
        0.5 * (sum of the parts for the first shape + sum for the second shape).
        When several rows match a shape, only the first one is used.

    Raises:
        ValueError: if either shape has no row in *df*.
    """
    parts = key.split("+")
    total = 0
    for dim_in, dim_out in ((embed_dim, embed_dim * 4), (embed_dim * 4, embed_dim)):
        rows = df[(df.dim_in == dim_in) & (df.dim_out == dim_out)]
        if rows.empty:
            raise ValueError(f"no benchmark row for dim_in={dim_in}, dim_out={dim_out}")
        for part in parts:
            total += rows[part].values[0]
    return total * 0.5


def _curve(df, dims, key):
    """Return ``(xs, ys)``: the dims in *dims* and the ``_avg_time`` of *key* at each."""
    xs = list(dims)
    return xs, [_avg_time(df, embed_dim, key) for embed_dim in xs]


def main():
    fig = plt.figure(tight_layout=True, figsize=(12, 3.5))
    gs = gridspec.GridSpec(1, 2)

    dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096]
    batch_size_for_plot1 = 32768
    batch_sizes_for_plot2 = [2**14, 2**15, 2**16, 2**17]
    dims_to_xtick = [1024, 2048, 4096]
    logscale_plot1 = True

    ax = fig.add_subplot(gs[0, 0])

    # TODO: change this to what you want.
    rdf = pd.read_json("speed_benchmark/info_a100_py2.jsonl", lines=True)
    df = rdf[rdf.batch_size == batch_size_for_plot1]

    # ---- Left panel: time occupied by the different operations ----
    for key, marker, ls, color, name in [
        (_STANDARD_TOTAL, "s", "-", "C2", "Standard fp16 (sum of parts)"),
        (_SWITCHBACK_TOTAL, "o", "-", "C4", "SwitchBack int8 (sum of parts)"),
        ("standard_fwd", "^", "--", "C2", "Matmul XW (standard)"),
        ("standard_gw", "^", "-.", "C2", "Matmul GW (standard)"),
        ("standard_gx", "^", ":", "gray", "Matmul GX (both)"),
        ("global_fwd", "^", "--", "C4", "Int8 Matmul XW (switchback)"),
        ("global_bwd", "^", "-.", "C4", "Int8 Matmul GW (switchback)"),
        ("x_quantize_rowwise", "P", "--", "C4", "Quantize rowwise X (switchback)"),
        ("g_quantize_rowwise", "P", "-.", "C4", "Quantize rowwise G (switchback)"),
        ("w_quantize_global", ".", "--", "C4", "Quantize global W (switchback)"),
        ("w_quantize_global_transpose", ".", "-.", "C4", "Quantize global and\ntranspose W (switchback)"),
    ]:
        xs, ys = _curve(df, dims_to_consider, key)
        ax.plot(
            xs,
            ys,
            color=color,
            label=name,
            marker=marker,
            markersize=5,
            linestyle=ls,
            # Thick lines for the two sum-of-parts totals, thin for single ops.
            linewidth=2 if "+" in key else 1.0,
        )

    ax.set_xlabel("dim", fontsize=13)
    ax.set_ylabel("time (ms)", fontsize=13)
    ax.grid()
    ax.set_xscale("log")
    if logscale_plot1:
        ax.set_yscale("log")
    ax.tick_params(axis="x", labelsize=11)
    ax.tick_params(axis="y", labelsize=11)
    ax.set_xticks(dims_to_xtick)
    ax.set_xticklabels(dims_to_xtick)
    ax.set_xticks([], minor=True)  # drop the log-scale minor ticks

    # Legend sits to the left of the first panel; bold the two totals.
    leg = ax.legend(loc="upper center", bbox_to_anchor=(-0.64, 1.0), ncol=1, fontsize=10)
    leg.get_texts()[0].set_fontweight("bold")
    leg.get_texts()[1].set_fontweight("bold")
    plt.subplots_adjust(left=0.1)
    ax.set_title(
        f" Linear layer, batch * sequence length = {batch_size_for_plot1 // 1024}k",
        fontsize=10,
        loc="left",
        y=1.05,
        pad=-20,
    )

    # ---- Right panel: % speedup of SwitchBack over standard per batch size ----
    ax = fig.add_subplot(gs[0, 1])
    markers = ["^", "v", "P", "o"]
    for j, batch_size in enumerate(batch_sizes_for_plot2):
        df = rdf[rdf.batch_size == batch_size]
        xs, standard_ys = _curve(df, dims_to_consider, _STANDARD_TOTAL)
        _, switchback_ys = _curve(df, dims_to_consider, _SWITCHBACK_TOTAL)
        # Positive when SwitchBack is faster: (t_std - t_sb) / t_std * 100.
        speedup = [
            -((sb - std) / std) * 100 for std, sb in zip(standard_ys, switchback_ys)
        ]
        ax.plot(
            xs,
            speedup,
            color=cmap(j / len(batch_sizes_for_plot2)),
            label=f"batch * sequence length = {batch_size}",
            marker=markers[j % len(markers)],
            markersize=5,
        )

    ax.legend()
    ax.set_xlabel("dim", fontsize=13)
    ax.set_xscale("log")
    ax.grid()
    ax.set_ylabel(r"% speedup", fontsize=13)
    ax.tick_params(axis="x", labelsize=11)
    ax.tick_params(axis="y", labelsize=11)
    ax.set_xticks(dims_to_xtick)
    ax.set_xticklabels(dims_to_xtick)
    ax.set_xticks([], minor=True)
    ax.set_title(" Linear layer summary, varying dimensions", fontsize=10, loc="left", y=1.05, pad=-20)

    plt.savefig("speed_benchmark/plot_with_info.pdf", bbox_inches="tight")


if __name__ == "__main__":
    main()