in evaluation_pipeline/evaluation.py [0:0]
def visualize_embeddings(model_name, queries, perplexity=25, random_state=42, output_dir='figs'):
    """Plot a 2-D t-SNE projection of query embeddings for one model.

    Each query is embedded with ``FeatureExtractor`` (one query per call), the
    embeddings are reduced to two dimensions with t-SNE, and the result is
    drawn as a scatter plot annotated with the line-wrapped query text. The
    figure is logged to Weights & Biases under the key ``"chart"`` when that
    succeeds, and is always saved as
    ``<output_dir>/t_sne_embeddings_<normalized model name>.png``.

    Args:
        model_name: Model name passed to ``FeatureExtractor``; also used in the
            plot title and, normalized, in the output file name.
        queries: Iterable of query strings to embed and plot. At least two are
            required for t-SNE.
        perplexity: Upper bound on the t-SNE perplexity. It is clamped to
            ``len(queries) - 1`` because scikit-learn's TSNE requires the
            perplexity to be smaller than the number of samples.
        random_state: Seed passed to TSNE so layouts are reproducible.
        output_dir: Directory the PNG is written to; created if missing.

    Raises:
        ValueError: If fewer than two queries are given.
    """
    import logging
    import os

    queries = list(queries)
    if len(queries) < 2:
        raise ValueError(
            f"visualize_embeddings needs at least 2 queries for t-SNE, got {len(queries)}"
        )

    fe = FeatureExtractor(EMBEDDING_MODELS_DICT, model_name=model_name)
    # Make the model name safe to embed in a file name.
    model_name_normalized = model_name.replace("/", "_").replace("-", "_").replace(".", "_")

    # Embed one query per call (as before); [0] takes the single returned row.
    embeddings = np.array([fe.get_embeddings([query])[0] for query in queries])

    # TSNE rejects perplexity >= n_samples, so clamp it for small query sets.
    effective_perplexity = min(perplexity, len(queries) - 1)
    tsne = TSNE(n_components=2, random_state=random_state, perplexity=effective_perplexity)
    reduced_embeddings = tsne.fit_transform(embeddings)

    # Wrap labels to at most 25 characters per line so annotations stay readable.
    wrapped_labels = ["\n".join(wrap(label, width=25)) for label in queries]

    fig, ax = plt.subplots(figsize=(10, 7))
    try:
        # Draw on this figure's axes explicitly rather than pyplot's "current" figure.
        ax.scatter(reduced_embeddings[:, 0], reduced_embeddings[:, 1], c='orange', alpha=0.6)
        for label, (x, y) in zip(wrapped_labels, reduced_embeddings):
            ax.annotate(label, (x, y), fontsize=7)
        ax.set_xlabel('Dimension 1')
        ax.set_ylabel('Dimension 2')
        ax.set_title(f't-SNE Visualization of Embeddings for {model_name}')

        # W&B logging is best-effort (e.g. no active run); it must not prevent saving.
        try:
            wandb.log({"chart": wandb.Image(fig)})
        except Exception:
            logging.getLogger(__name__).warning(
                "Could not log t-SNE chart to wandb", exc_info=True
            )

        os.makedirs(output_dir, exist_ok=True)
        fig.savefig(
            os.path.join(output_dir, f't_sne_embeddings_{model_name_normalized}.png'),
            format='png',
            dpi=300,
            bbox_inches='tight',
        )
    finally:
        # Always release the figure, even if plotting, logging or saving fails.
        plt.close(fig)