def plot_embedding()

in decoder.py [0:0]


def _palette_colors(col_palette, n):
	"""Return a list of ``n`` colors taken from ``col_palette``.

	``col_palette`` may be a colormap name, a discrete colormap exposing a
	``colors`` list (e.g. ``plt.get_cmap("tab10")``), any other callable
	colormap, or a plain sequence of colors. Discrete palettes are cycled
	when they have fewer than ``n`` entries so every label still gets a color.
	"""
	if isinstance(col_palette, str):
		col_palette = plt.get_cmap(col_palette)
	if hasattr(col_palette, 'colors'):
		# Discrete colormap: use its colors in their defined order.
		colors = list(col_palette.colors)
	elif callable(col_palette):
		# Continuous colormap: sample n evenly spaced points along it.
		return [col_palette(v) for v in np.linspace(0.0, 1.0, n)]
	else:
		colors = list(col_palette)
	if not colors:
		raise ValueError("col_palette contains no colors")
	return [colors[i % len(colors)] for i in range(n)]


def plot_embedding(x, labels=None, labels_name='labels', labels_order=None,
					   file_name=None, coldict=None,
					   d1=6.5, d2=6.0, fs=4, ms=20,
					   col_palette=plt.get_cmap("tab10"), bbox=(1.3, 0.7),
					   legend_loc='center left'):
	"""Scatter-plot a 2-D embedding, optionally colored by label.

	Points are drawn in a random order (``np.random.permutation``) so that
	no single class is systematically painted on top of the others.

	Args:
		x: array-like with two columns (the embedding coordinates); one row
			per point.
		labels: optional array-like with one label per row of ``x``.
		labels_name: DataFrame column name used for the labels (also the
			hue variable passed to seaborn).
		labels_order: order of labels for hue/legend; defaults to
			``np.unique(labels)``.
		file_name: if truthy, the figure is saved to ``file_name + '.png'``
			at 300 dpi.
		coldict: optional mapping label -> color; built from ``col_palette``
			when omitted.
		d1, d2: figure width and height in inches.
		fs: legend font size.
		ms: marker size passed to ``sns.scatterplot`` as ``s``.
		col_palette: colormap, colormap name, or sequence of colors used to
			build ``coldict`` (one color per distinct label, cycled if short).
		bbox: ``bbox_to_anchor`` for the legend, in axes coordinates.
		legend_loc: which point of the legend box is placed at ``bbox``; the
			default puts the legend to the right of the axes.
	"""
	# Accept lists as well as arrays; the fancy indexing below needs an ndarray.
	x = np.asarray(x)
	idx = np.random.permutation(len(x))
	df = pd.DataFrame(x[idx, :], columns=['x1', 'x2'])

	fig = plt.figure(figsize=(d1, d2))
	ax = plt.gca()

	if labels is not None:
		labels = np.asarray(labels)
		df[labels_name] = labels[idx]
		if labels_order is None:
			labels_order = np.unique(labels)
		if coldict is None:
			# One color per distinct label (not per point). A Colormap object
			# cannot be sliced, so colors are extracted explicitly.
			coldict = dict(zip(labels_order,
							   _palette_colors(col_palette, len(labels_order))))
		sns.scatterplot(x="x1", y="x2", hue=labels_name,
						hue_order=labels_order,
						palette=coldict,
						alpha=1.0, edgecolor="none",
						data=df, ax=ax, s=ms)
		# 'outside' is not a valid Axes.legend location. An explicit corner
		# anchored at `bbox` places the legend beside the axes instead.
		ax.legend(fontsize=fs, loc=legend_loc, bbox_to_anchor=bbox)
	else:
		sns.scatterplot(x="x1", y="x2",
						data=df, ax=ax, s=ms)

	# Hide axes before computing the layout so it accounts for them.
	ax.axis('off')
	fig.tight_layout()

	if file_name:
		# bbox_inches='tight' keeps a legend placed outside the axes in the image.
		fig.savefig(file_name + '.png', format='png', dpi=300,
					bbox_inches='tight')