def visualize_embeddings()

in evaluation_pipeline/evaluation.py [0:0]


def visualize_embeddings(model_name, queries, perplexity=25, output_dir="figs"):
    """Plot a 2-D t-SNE projection of the embeddings of ``queries``.

    Each query is embedded with a ``FeatureExtractor`` built for
    ``model_name``, the embeddings are reduced to two dimensions with t-SNE,
    and the points are drawn as a scatter plot annotated with the (wrapped)
    query text. The figure is logged to Weights & Biases on a best-effort
    basis and always saved as a PNG under ``output_dir``.

    Args:
        model_name: Key/name of the embedding model; also used (with ``/``,
            ``-`` and ``.`` replaced by ``_``) in the output file name.
        queries: Sequence of query strings to embed and label. At least two
            are required.
        perplexity: Upper bound for the t-SNE perplexity. It is lowered to
            ``len(queries) - 1`` when there are too few queries, because
            t-SNE requires the perplexity to be smaller than the sample count.
        output_dir: Directory the PNG is written to; created if missing.

    Raises:
        ValueError: If fewer than two queries are supplied.
    """
    # Local imports keep this block self-contained; both are stdlib.
    import logging
    import os

    fe = FeatureExtractor(EMBEDDING_MODELS_DICT, model_name=model_name)
    model_name_normalized = (
        model_name.replace("/", "_").replace("-", "_").replace(".", "_")
    )

    # Queries are embedded one at a time, exactly as before; the first element
    # of each result is the embedding of that single query.
    embeddings = [fe.get_embeddings([query])[0] for query in queries]

    n_samples = len(embeddings)
    if n_samples < 2:
        raise ValueError(
            "visualize_embeddings needs at least 2 queries to run t-SNE, "
            f"got {n_samples}"
        )

    # t-SNE rejects perplexity >= n_samples, so clamp it for small inputs.
    effective_perplexity = min(perplexity, n_samples - 1)
    tsne = TSNE(n_components=2, random_state=42, perplexity=effective_perplexity)
    reduced_embeddings = tsne.fit_transform(np.array(embeddings))

    # Wrap each label to at most 25 characters per line so annotations stay
    # compact on the plot.
    wrapped_labels = ["\n".join(wrap(label, width=25)) for label in queries]

    fig, ax = plt.subplots(figsize=(10, 7))
    try:
        ax.scatter(
            reduced_embeddings[:, 0],
            reduced_embeddings[:, 1],
            c='orange',
            alpha=0.6,
        )

        for i, txt in enumerate(wrapped_labels):
            ax.annotate(
                txt,
                (reduced_embeddings[i, 0], reduced_embeddings[i, 1]),
                fontsize=7,
            )

        ax.set_xlabel('Dimension 1')
        ax.set_ylabel('Dimension 2')
        ax.set_title(f't-SNE Visualization of Embeddings for {model_name}')

        # Logging to W&B is optional (e.g. no active run); never let it
        # prevent the figure from being written to disk.
        try:
            wandb.log({"chart": wandb.Image(fig)})
        except Exception as exc:
            logging.getLogger(__name__).warning(
                "Skipping wandb logging of t-SNE figure: %s", exc
            )

        os.makedirs(output_dir, exist_ok=True)
        fig.savefig(
            os.path.join(output_dir, f't_sne_embeddings_{model_name_normalized}.png'),
            format='png',
            dpi=300,
            bbox_inches='tight',
        )
    finally:
        # Close this specific figure even if plotting or saving failed.
        plt.close(fig)