in log_analyzer/plot.py [0:0]
def log_boxplot(
    detailed_comm_info: Dict,
    title: str = "for deepspeed Zero3 llama 13B",
):
    """Draw box plots of elapsed time per message size for each communication kind.

    Entries are grouped by ``(comm_type, comm_group)``. Each group is split into
    panels of at most ``MAX_ITEMS`` message sizes, and the panels are laid out in
    a grid that is ``COLS`` columns wide. The number of samples behind every box
    is written in red just above that box's maximum value.

    Args:
        detailed_comm_info: Mapping keyed by ``(comm_type, comm_group, msg_size)``
            tuples. Each value must be subscriptable with ``"_elapsed_time"``,
            yielding a sequence of numeric samples. The keys are sorted, so they
            must be mutually orderable, and ``comm_type`` / ``comm_group`` must
            expose a ``.value`` attribute (presumably Enum members -- confirm
            against the caller).
        title: Figure-level title. Defaults to the previously hard-coded string
            so existing callers get the same output.

    Returns:
        None. The figure is displayed with ``plt.show()``; when
        ``detailed_comm_info`` is empty nothing is drawn.
    """
    MAX_ITEMS, COLS = 5, 2

    # (comm_type, comm_group) -> {msg_size: ndarray of elapsed-time samples}.
    # Iterating the sorted keys keeps groups and message sizes in a stable order.
    samples_by_comm = {}
    for comm_type, comm_group, msg_size in sorted(detailed_comm_info):
        elapsed_time = np.array(
            detailed_comm_info[(comm_type, comm_group, msg_size)]["_elapsed_time"]
        )
        samples_by_comm.setdefault((comm_type, comm_group), {})[msg_size] = elapsed_time

    # Every group needs ceil(len / MAX_ITEMS) panels.
    fig_num = sum(
        (len(comm_info) + MAX_ITEMS - 1) // MAX_ITEMS
        for comm_info in samples_by_comm.values()
    )
    if fig_num == 0:
        # Nothing to plot; plt.subplots() rejects a grid with zero rows.
        return

    fig_rows = (fig_num + COLS - 1) // COLS
    # squeeze=False keeps `axes` two-dimensional even for a single row, so the
    # panels can always be addressed uniformly.
    fig, axes = plt.subplots(
        nrows=fig_rows, ncols=COLS, figsize=(8, 6), squeeze=False
    )
    fig.suptitle(title)
    panels = axes.ravel()  # row-major order, same as [idx // COLS][idx % COLS]

    fig_idx = 0
    for (comm_type, comm_group), comm_info in samples_by_comm.items():
        values = list(comm_info.values())
        labels = [convert_size_to_msg(msg) for msg in comm_info]
        for start in range(0, len(values), MAX_ITEMS):
            ax = panels[fig_idx]
            fig_idx += 1
            chunk = values[start : start + MAX_ITEMS]
            ax.set_title("%s %s msg info" % (comm_type.value, comm_group.value))
            # NOTE(review): Matplotlib 3.9 renamed `labels` to `tick_labels`;
            # the old keyword is kept here for compatibility with older versions.
            ax.boxplot(
                chunk,
                labels=labels[start : start + MAX_ITEMS],
                flierprops=dict(
                    marker="o", markerfacecolor="black", markersize=2, linestyle="none"
                ),
            )
            # Boxes sit at x = 1, 2, ...; annotate each one with its sample count.
            for position, samples in enumerate(chunk, start=1):
                if samples.size == 0:
                    # np.max() raises on an empty array and there is no box to label.
                    continue
                ax.text(
                    x=position,
                    y=np.max(samples) * 1.01,
                    s=str(len(samples)),
                    horizontalalignment="center",
                    size="x-small",
                    color="r",
                    weight="semibold",
                )

    # Hide grid cells left over when the panel count is not a multiple of COLS.
    for unused_ax in panels[fig_idx:]:
        unused_ax.set_visible(False)

    # Run the layout pass only after all artists exist so that panel titles and
    # tick labels are included in the spacing calculation.
    fig.tight_layout()
    plt.show()