in visualize.py [0:0]
def _draw_disc_panel(groups, color_dict, title, fs, ms, linestyle='-'):
    """Draw one unit-disc scatter panel into the current matplotlib axes.

    Draws an unfilled boundary circle of radius 1, an origin marker, one
    hollow-circle series per label group, and hides the axes.

    Parameters
    ----------
    groups : pandas GroupBy
        Result of ``DataFrame.groupby('label')`` on a frame with ``x`` and
        ``y`` columns.
    color_dict : dict
        Maps every group label to a matplotlib color.
    title : str or None
        Panel title.
    fs : float
        Title font size.
    ms : float
        Marker size for the points and the origin markers.
    linestyle : str
        Line style of the boundary circle.
    """
    # Explicit facecolor/edgecolor instead of fc= plus color=: mixing them
    # only works by accident of matplotlib's kwarg application order.
    boundary = plt.Circle((0, 0), radius=1, facecolor='none',
                          edgecolor='black', linestyle=linestyle)
    plt.gca().add_patch(boundary)
    plt.plot(0, 0, 'x', c=(0, 0, 0), ms=ms)
    plt.title(title, fontsize=fs)
    for name, group in groups:
        plt.plot(group.x, group.y, marker='o', markerfacecolor='none',
                 c=color_dict[name], linestyle='', ms=ms, label=name)
    # A white cross of the same size drawn last, over the black origin cross
    # and any points there. NOTE(review): this visually hides the black
    # origin marker; kept from the original drawing order -- confirm intent.
    plt.plot(0, 0, 'x', c=(1, 1, 1), ms=ms)
    plt.axis('off')
    plt.axis('equal')


def plotPoincareDisc(x,
                     label_names=None,
                     file_name=None,
                     title_name=None,
                     idx_zoom=None,
                     show=False,
                     d1=12,
                     d2=6,
                     fs=11,
                     ms=4,
                     col_palette=None,
                     color_dict=None):
    """Plot 2-D embeddings inside the unit disc, plus a rescaled "zoom in" view.

    The left panel scatters the raw coordinates, colored by label, with each
    label's name written at the median position of its points. The right
    panel scatters the output of ``linear_scale`` (defined elsewhere in this
    module) for either all points or the subset ``idx_zoom``, with a legend.

    Parameters
    ----------
    x : array-like
        Coordinates indexed as ``x[0]`` (horizontal) and ``x[1]`` (vertical),
        i.e. shape ``(2, N)``.
    label_names : array-like of length N, optional
        Per-point labels used for grouping and coloring. If None, all points
        form a single unnamed group.
    file_name : str or path-like, optional
        If given, the figure is saved to ``f"{file_name}.png"``.
    title_name : str, optional
        Title of the left panel.
    idx_zoom : index array or boolean mask, optional
        Points shown in the right panel; all points if None.
    show : bool
        Call ``plt.show()`` before closing the figure.
    d1, d2 : float
        Figure width and height in inches.
    fs : float
        Font size for titles, annotations and legend.
    ms : float
        Marker size.
    col_palette : sequence of colors, optional
        Palette indexed by group position; defaults to the module-level
        ``colors_palette``. It is cycled when there are more labels than
        colors.
    color_dict : dict, optional
        Label -> color mapping. If None, one is built from ``col_palette``
        in sorted label order.

    Returns
    -------
    dict
        The label -> color mapping used, so callers can reuse it for
        consistent coloring across plots.
    """
    x = np.asarray(x)
    if label_names is None:
        # pandas' groupby drops None keys, so without this nothing would be
        # drawn; use one unnamed group instead.
        label_names = np.full(x.shape[1], '', dtype=object)
    else:
        # An ndarray supports fancy indexing with idx_zoom (a plain list
        # does not).
        label_names = np.asarray(label_names)
    if col_palette is None:
        col_palette = colors_palette

    groups = pd.DataFrame(dict(x=x[0], y=x[1], label=label_names)).groupby('label')
    if color_dict is None:
        color_dict = {name: col_palette[j % len(col_palette)]
                      for j, (name, _) in enumerate(groups)}

    if idx_zoom is None:
        zoom_x, zoom_labels = x, label_names
    else:
        zoom_x, zoom_labels = x[:, idx_zoom], label_names[idx_zoom]
    xl = np.array(linear_scale(zoom_x))
    # Replace NaNs in the rescaled output with 0 so they plot at the origin.
    xl[np.isnan(xl)] = 0
    zoom_groups = pd.DataFrame(
        dict(x=xl[0], y=xl[1], label=zoom_labels)).groupby('label')

    fig = plt.figure(figsize=(d1, d2), dpi=300)
    try:
        plt.subplot(1, 2, 1)
        _draw_disc_panel(groups, color_dict, title_name, fs, ms)
        # Annotate each label at the median position of its points.
        for name, group in groups:
            plt.text(group.x.median(), group.y.median(), name, fontsize=fs)

        plt.subplot(1, 2, 2)
        _draw_disc_panel(zoom_groups, color_dict, 'zoom in', fs, ms,
                         linestyle=':')
        # Skip the legend when no artist carries a label (e.g. unlabeled
        # input), which would otherwise only emit a matplotlib warning.
        if plt.gca().get_legend_handles_labels()[0]:
            plt.legend(numpoints=1, loc='center left',
                       bbox_to_anchor=(1, 0.5), fontsize=fs)
        plt.tight_layout()
        if file_name:
            plt.savefig(f'{file_name}.png', format='png')
        if show:
            plt.show()
    finally:
        # Always release the figure, even if a plotting call raised.
        plt.close(fig)
    return color_dict