def plot()

in xformers/benchmarks/benchmark_encoder.py


# Module-level imports assumed by this excerpt (the pd, sns and plt aliases are used below):
from typing import Any, Dict, List

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def plot(args, results: List[Dict[str, Any]]):
    df = pd.DataFrame(results)

    # Fix each swept option to its last configured value so the plots vary
    # only over sequence length and attention mechanism.
    HEADS = args.heads[-1]
    AMP = args.pytorch_amp[-1]
    EMB = args.embedding_dim[-1]
    CAUSAL = args.causal[-1]
    BATCH_SIZE = args.batch_size[-1]
    ACTIVATION = args.activations[-1]

    # Keep only the rows matching the selected configuration. Work on a copy
    # so the in-place sorts below do not trigger SettingWithCopyWarning.
    df_filtered = df[
        (df["activation"] == ACTIVATION)
        & (df["heads"] == HEADS)
        & (df["autocast"] == AMP)
        & (df["embed_dim"] == EMB)
        & (df["causal"] == CAUSAL)
        & (df["batch_size"] == BATCH_SIZE)
    ].copy()

    # Peak memory per attention mechanism, grouped by sequence length.
    df_filtered.sort_values(
        by=["sequence_length", "max_memory"], ascending=[False, True], inplace=True
    )
    sns.barplot(
        x="sequence_length",
        y="max_memory",
        hue="attention_name",
        data=df_filtered,
        palette="Set2",
    )
    plt.xlabel("Sequence length")
    plt.ylabel("Max memory being used")
    plt.title("Memory use")
    plt.savefig("memory_vs_attention.png")
    plt.clf()

    # Same comparison, this time for average epoch runtime.
    df_filtered.sort_values(
        by=["sequence_length", "run_time"], ascending=[False, True], inplace=True
    )
    sns.barplot(
        x="sequence_length",
        y="run_time",
        hue="attention_name",
        data=df_filtered,
        palette="Set2",
    )
    plt.xlabel("Sequence length")
    plt.ylabel("Average epoch time")
    plt.title("Runtime")
    plt.savefig("runtime_vs_attention.png")