# Excerpt: plotPoincareDisc() from visualize.py


def plotPoincareDisc(x,
                     label_names=None,
                     file_name=None,
                     title_name=None,
                     idx_zoom=None,
                     show=False,
                     d1=12,
                     d2=6,
                     fs=11,
                     ms=4,
                     col_palette=None,
                     color_dict=None):
    """Plot labelled 2-D points on a unit disc next to a rescaled "zoom in" view.

    The left panel draws the points as given inside a solid unit circle and
    writes each label at the median position of its points. The right panel
    draws the points after ``linear_scale`` (a helper defined elsewhere in this
    module), optionally restricted to ``idx_zoom``, inside a dotted unit
    circle, with a legend placed to the right.

    Args:
        x: Coordinates indexed as ``x[0]`` / ``x[1]`` and ``x[:, idx]``
            (e.g. a ``(2, n)`` NumPy array); column ``i`` is point ``i``.
        label_names: Per-point labels (length ``n``). Points are grouped and
            coloured by label, so labels are effectively required even though
            the default is ``None``. Lists, arrays and Series are accepted.
        file_name: If truthy, the figure is saved to ``file_name + '.png'``.
        title_name: Title of the left panel.
        idx_zoom: Optional positional index selecting the points shown in the
            zoom panel (applied to both ``x[:, idx_zoom]`` and the labels);
            all points are shown when ``None``.
        show: If True, call ``plt.show()`` before the figure is closed.
        d1, d2: Figure width and height in inches (figure created at 300 dpi).
        fs: Font size of titles, per-label text and legend.
        ms: Marker size of the points and origin marks.
        col_palette: Integer-indexable sequence of colours used to build
            ``color_dict``; defaults to the module-level ``colors_palette``.
            Colours are reused cyclically when there are more labels than
            colours.
        color_dict: Optional mapping ``label -> colour``; it must contain
            every label. Built from ``col_palette`` when ``None``.

    Returns:
        The ``label -> colour`` mapping that was used, so that later plots can
        reuse the same colours.
    """
    if col_palette is None:
        col_palette = colors_palette

    # Convert once so `labels == lab` is elementwise and `labels[idx_zoom]` is
    # positional, even when a plain list or an indexed Series is passed.
    labels = None if label_names is None else np.asarray(label_names)

    groups = pd.DataFrame(dict(x=x[0], y=x[1], label=labels)).groupby('label')

    if color_dict is None:
        # Colours follow groupby iteration order; cycle so a short palette
        # does not raise IndexError. Built before the figure exists so a
        # failure here cannot leak an unclosed figure.
        color_dict = {name: col_palette[j % len(col_palette)]
                      for j, (name, _) in enumerate(groups)}

    fig = plt.figure(figsize=(d1, d2), dpi=300)

    # Left panel: coordinates as given, solid circle, one text per label.
    _plot_disc_panel(1, groups, color_dict, title_name, fs=fs, ms=ms)
    for lab in np.unique(labels):
        ix = np.where(labels == lab)[0]
        plt.text(np.median(x[0, ix]), np.median(x[1, ix]), lab, fontsize=fs)

    # Right panel: rescaled coordinates of all points or of the idx_zoom subset.
    if idx_zoom is None:
        x_zoom, labels_zoom = x, labels
    else:
        x_zoom, labels_zoom = x[:, idx_zoom], labels[idx_zoom]
    xl = np.array(linear_scale(x_zoom))
    xl[np.isnan(xl)] = 0  # NaNs returned by linear_scale are pinned to the origin
    zoom_groups = pd.DataFrame(
        dict(x=xl[0], y=xl[1], label=labels_zoom)).groupby('label')

    _plot_disc_panel(2, zoom_groups, color_dict, 'zoom in', fs=fs, ms=ms,
                     circle_style=':')
    plt.legend(numpoints=1, loc='center left',
               bbox_to_anchor=(1, 0.5), fontsize=fs)

    plt.tight_layout()

    if file_name:
        plt.savefig(file_name + '.png', format='png')

    if show:
        plt.show()

    plt.close(fig)

    return color_dict


def _plot_disc_panel(position, groups, color_dict, title, fs, ms,
                     circle_style='-'):
    """Draw one panel of the 1x2 figure: unit circle, origin marks, grouped points.

    Args:
        position: Subplot index (1 or 2) within the 1x2 grid.
        groups: pandas GroupBy over a frame with ``x``, ``y`` and ``label``.
        color_dict: Mapping ``label -> colour`` covering every group name.
        title: Panel title.
        fs: Title font size.
        ms: Marker size for points and origin marks.
        circle_style: Matplotlib line style of the unit circle.
    """
    plt.subplot(1, 2, position)
    plt.gca().add_patch(plt.Circle((0, 0), radius=1, fc='none',
                                   color='black', linestyle=circle_style))
    plt.plot(0, 0, 'x', c=(0, 0, 0), ms=ms)
    plt.title(title, fontsize=fs)

    for name, group in groups:
        plt.plot(group.x, group.y, marker='o', markerfacecolor='none',
                 c=color_dict[name], linestyle='', ms=ms, label=name)

    # White 'x' drawn over the origin after the points, same size as the
    # black mark above.
    plt.plot(0, 0, 'x', c=(1, 1, 1), ms=ms)
    plt.axis('off')
    plt.axis('equal')