def plot_poincare_disc()

in visualize.py [0:0]


def plot_poincare_disc(x, labels=None, labels_name='labels', labels_order=None,
                       file_name=None, coldict=None,
                       d1=19, d2=18.0, fs=11, ms=20, col_palette=plt.get_cmap("tab10"), bbox=(1.3, 0.7)):
    """Scatter-plot 2-D coordinates inside the unit circle (Poincaré disc).

    Points are drawn in a random order so that no group is systematically
    painted over another. When ``labels`` is given, points are coloured by
    label, a legend is added, and each label's name is written at the median
    position of its points. The figure is closed before returning, so it is
    only persisted when ``file_name`` is given.

    Args:
        x: array-like with two columns (the two plotted coordinates), one row
            per point.
        labels: optional array-like with one label per row of ``x``.
        labels_name: column name / legend title used for the labels.
        labels_order: order of the label categories in the legend; defaults
            to ``np.unique(labels)``.
        file_name: if truthy, the figure is written to ``file_name + '.png'``.
        coldict: optional mapping label -> colour. When omitted it is built
            from ``col_palette``.
        d1, d2: figure width and height passed to ``plt.figure(figsize=...)``.
        fs: font size of the legend and of the per-label annotations.
        ms: marker size passed to ``seaborn.scatterplot`` as ``s``.
        col_palette: a matplotlib colormap or a sequence of colours used to
            build ``coldict``. Colours are reused cyclically when there are
            more categories than listed colours.
        bbox: ``bbox_to_anchor`` for the legend.
    """
    x = np.asarray(x)
    # Shuffle the drawing order so that no label consistently hides the others.
    idx = np.random.permutation(len(x))
    df = pd.DataFrame(x[idx, :], columns=['pm1', 'pm2'])

    fig = plt.figure(figsize=(d1, d2))
    ax = plt.gca()
    # Outline of the unit disc (no fill) plus a dot marking its centre.
    circle = plt.Circle((0, 0), radius=1, facecolor='none', edgecolor='black')
    ax.add_patch(circle)
    ax.plot(0, 0, '.', c=(0, 0, 0), ms=4)

    if labels is not None:
        # Accept lists as well as arrays: fancy indexing and elementwise ==
        # below both require an ndarray.
        labels = np.asarray(labels)
        df[labels_name] = labels[idx]
        if labels_order is None:
            labels_order = np.unique(labels)
        if coldict is None:
            colors = _category_colors(col_palette, len(labels_order))
            coldict = dict(zip(labels_order, colors))
        sns.scatterplot(x="pm1", y="pm2", hue=labels_name,
                        hue_order=labels_order,
                        palette=coldict,
                        alpha=1.0, edgecolor="none",
                        data=df, ax=ax, s=ms)
        ax.legend(fontsize=fs, loc='best', bbox_to_anchor=bbox)
        _annotate_label_medians(ax, x, labels, fs)
    else:
        sns.scatterplot(x="pm1", y="pm2",
                        data=df, ax=ax, s=ms)

    ax.axis('off')
    ax.axis('equal')
    # Compute the layout only after all artists and axis settings are final.
    fig.tight_layout()

    if file_name:
        fig.savefig(file_name + '.png', format='png')

    plt.close(fig)


def _category_colors(palette, n):
    """Return ``n`` colours drawn from a colormap or a sequence of colours."""
    if callable(palette):
        # matplotlib colormaps are callable; listed ones expose discrete colours.
        listed = getattr(palette, 'colors', None)
        if listed is None:
            # Continuous colormap: sample it evenly over [0, 1].
            return [palette(v) for v in np.linspace(0, 1, n)]
        palette = listed
    palette = list(palette)
    # Cycle through the palette so every category gets a colour.
    return [palette[i % len(palette)] for i in range(n)]


def _annotate_label_medians(ax, x, labels, fontsize):
    """Write each unique label at the median coordinates of its points."""
    for lab in np.unique(labels):
        members = labels == lab
        ax.text(np.median(x[members, 0]), np.median(x[members, 1]), lab,
                fontsize=fontsize)