in xformers/benchmarks/benchmark_encoder.py [0:0]
def plot(args, results: List[Dict[str, Any]]) -> None:
    """Save bar charts of peak memory and runtime per attention mechanism.

    Only one configuration is plotted: the rows of ``results`` matching the
    last entry of each of ``args.heads``, ``args.pytorch_amp``,
    ``args.embedding_dim``, ``args.causal``, ``args.batch_size`` and
    ``args.activations``. Within that configuration, bars are grouped by
    sequence length and coloured by attention name.

    Args:
        args: Object exposing the indexable attributes listed above.
        results: One mapping per benchmark run. Each must provide the keys
            "activation", "heads", "autocast", "embed_dim", "causal",
            "batch_size", "sequence_length", "attention_name",
            "max_memory" and "run_time".

    Side effects:
        Writes "memory_vs_attention.png" and "runtime_vs_attention.png"
        using the current matplotlib figure.
    """
    df = pd.DataFrame(results)

    # Pin every swept option to the last value that was requested.
    heads = args.heads[-1]
    amp = args.pytorch_amp[-1]
    emb = args.embedding_dim[-1]
    causal = args.causal[-1]
    batch_size = args.batch_size[-1]
    activation = args.activations[-1]

    df_filtered = df[
        (df["activation"] == activation)
        & (df["heads"] == heads)
        & (df["autocast"] == amp)
        & (df["embed_dim"] == emb)
        & (df["causal"] == causal)
        & (df["batch_size"] == batch_size)
    ]

    # Rebind instead of sorting in place: `df_filtered` is a selection taken
    # from `df`, and mutating it in place makes pandas emit
    # SettingWithCopyWarning.
    df_filtered = df_filtered.sort_values(
        by=["sequence_length", "max_memory"], ascending=[False, True]
    )
    sns.barplot(
        x="sequence_length",
        y="max_memory",
        hue="attention_name",
        data=df_filtered,
        palette="Set2",
    )
    plt.xlabel("Sequence length")
    plt.ylabel("Max memory being used")
    plt.title("Memory use")
    plt.savefig("memory_vs_attention.png")

    # Reset the figure so the runtime bars are not drawn over the memory ones.
    plt.clf()

    df_filtered = df_filtered.sort_values(
        by=["sequence_length", "run_time"], ascending=[False, True]
    )
    sns.barplot(
        x="sequence_length",
        y="run_time",
        hue="attention_name",
        data=df_filtered,
        palette="Set2",
    )
    plt.xlabel("Sequence length")
    plt.ylabel("Average epoch time")
    plt.title("Runtime")
    plt.savefig("runtime_vs_attention.png")