def plot()

in pyhanabi/tools/action_matrix.py [0:0]


def plot(mat, title, num_player, *, fig=None, ax=None, savefig=None):
    """Draw ``mat`` with ``Axes.matshow`` and label ticks with action names.

    Args:
        mat: Array-like passed to ``Axes.matshow``. Presumably a square matrix
            with one row/column per action for ``num_player`` players
            (NOTE(review): confirm against callers).
        title: Title set on the axes.
        num_player: Number of players. For 2 the x/y ticks are labelled with
            ``idx2action``, for 3 with ``idx2action_p3``; any other value
            leaves matplotlib's default ticks.
        fig: Optional figure to draw into. Only consulted when ``ax`` is None,
            in which case the figure's current axes (``fig.gca()``) are used.
        ax: Optional axes to draw into. If neither ``fig`` nor ``ax`` is
            given, a new 8x8-inch figure with a single axes is created.
        savefig: Optional target passed to ``Figure.savefig``. When given, the
            figure that owns ``ax`` is tight-laid-out and saved.

    Returns:
        The image object returned by ``ax.matshow`` (usable with
        ``fig.colorbar``).
    """
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots(figsize=(8, 8))
        else:
            # A figure without explicit axes used to crash on ``ax.matshow``.
            ax = fig.gca()
    # Work on the figure that owns ``ax``: pyplot-level calls act on the
    # *current* figure, which need not be the one we are drawing into.
    fig = ax.figure

    image = ax.matshow(mat)
    ax.set_title(title)

    if num_player == 2:
        labels = idx2action
    elif num_player == 3:
        labels = idx2action_p3
    else:
        labels = None

    if labels is not None:
        # Derive the tick count from the label list so the two cannot drift
        # apart; recent matplotlib rejects a label/tick count mismatch.
        ticks = range(len(labels))
        ax.set_xticks(ticks)
        ax.set_xticklabels(labels)
        ax.set_yticks(ticks)
        ax.set_yticklabels(labels)

    if savefig is not None:
        fig.tight_layout()
        fig.savefig(savefig)
    return image