def log_boxplot()

in log_analyzer/plot.py [0:0]


def log_boxplot(
    detailed_comm_info: Dict, title: str = "for deepspeed Zero3 llama 13B"
) -> None:
    """Draw box plots of elapsed time, grouped by communication type and group.

    Entries are grouped by ``(comm_type, comm_group)``; within a group every
    message size becomes one box. Groups with more than ``MAX_ITEMS`` message
    sizes are split over several panels, and panels are laid out on a grid
    that is ``COLS`` columns wide. The number of samples behind each box is
    written in red just above the largest sample.

    Args:
        detailed_comm_info: Mapping keyed by ``(comm_type, comm_group,
            msg_size)`` tuples. Each value must provide an ``"_elapsed_time"``
            entry holding a sequence of numbers. ``comm_type`` and
            ``comm_group`` must expose a ``.value`` attribute (used for the
            panel title), and the keys must be sortable.
        title: Figure-level title. Defaults to the title that was previously
            hard-coded.

    Returns:
        None. The figure is displayed with ``plt.show()``. Nothing is drawn
        when ``detailed_comm_info`` is empty.
    """
    MAX_ITEMS, COLS = 5, 2

    # (comm_type, comm_group) -> {msg_size: array of elapsed times}.
    # Keys are visited in sorted order, so message sizes are inserted in
    # ascending order within each group.
    grouped = {}
    for comm_type, comm_group, msg_size in sorted(detailed_comm_info.keys()):
        elapsed_time = np.array(
            detailed_comm_info[(comm_type, comm_group, msg_size)]["_elapsed_time"]
        )
        grouped.setdefault((comm_type, comm_group), {})[msg_size] = elapsed_time

    # Each group needs ceil(len / MAX_ITEMS) panels.
    fig_num = sum(
        (len(comm_info) + MAX_ITEMS - 1) // MAX_ITEMS for comm_info in grouped.values()
    )
    if fig_num == 0:
        # A subplot grid cannot be created with zero rows.
        return

    fig_rows = (fig_num + COLS - 1) // COLS
    # squeeze=False keeps `axes` two-dimensional even when there is only one
    # row; otherwise a single-row grid comes back as a 1-D array.
    fig, axes = plt.subplots(
        nrows=fig_rows, ncols=COLS, figsize=(8, 6), squeeze=False
    )
    fig.suptitle(title)
    panels = axes.ravel()  # row-major, same order as [idx // COLS][idx % COLS]

    fig_idx = 0
    for (comm_type, comm_group), comm_info in grouped.items():
        values = list(comm_info.values())
        labels = [convert_size_to_msg(msg) for msg in comm_info.keys()]
        for start in range(0, len(values), MAX_ITEMS):
            ax = panels[fig_idx]
            fig_idx += 1
            chunk = values[start : start + MAX_ITEMS]
            ax.set_title("%s %s msg info" % (comm_type.value, comm_group.value))
            ax.boxplot(
                chunk,
                labels=labels[start : start + MAX_ITEMS],
                flierprops=dict(
                    marker="o", markerfacecolor="black", markersize=2, linestyle="none"
                ),
            )
            # Boxes sit at x = 1, 2, ...; put the sample count above each one.
            for position, sample in enumerate(chunk, start=1):
                if sample.size == 0:
                    # np.max raises on an empty array; nothing to anchor to.
                    continue
                ax.text(
                    x=position,
                    y=np.max(sample) * 1.01,
                    s=str(len(sample)),
                    horizontalalignment="center",
                    size="x-small",
                    color="r",
                    weight="semibold",
                )

    # Hide grid cells that received no panel (odd number of panels).
    for unused_ax in panels[fig_idx:]:
        unused_ax.set_visible(False)

    # Lay out only after every title and artist exists, so it can take effect.
    fig.tight_layout()
    plt.show()