def visualize_embeddings_umap()

in tools/visualize.py [0:0]


def visualize_embeddings_umap(emb_l,
                              output_dir="",
                              max_size=500000,
                              umap_metric="euclidean",
                              cat_counts=None,
                              use_max_count=True):
    """Plot per-table norm histograms and 2-D UMAP projections of embeddings.

    For every table ``k`` in ``emb_l`` this writes, into ``output_dir``:

    * ``cat-norm-histogram-<k>.png``: histogram of the per-row L2 norms on
      log-spaced bins (rows with a zero norm cannot be placed on a log axis
      and are left out of that plot);
    * ``cat-<k>-<n_vis>-of-<rows>-umap.png``: UMAP projection of the selected
      rows, coloured by ``log(count + 1) / log(max_count + 1)`` when
      ``cat_counts`` is given.

    Tables with fewer than 20 rows, or for which fewer than 20 rows survive
    the count filter, only get the histogram.

    Args:
        emb_l: sequence of modules whose ``.weight`` supports
            ``.detach().cpu().numpy()`` (presumably torch embedding layers);
            each weight is treated as a (rows, dim) matrix.
        output_dir: directory the PNGs are written to; ``""`` (the default)
            means the current working directory.
        max_size: maximum number of rows projected with UMAP per table.
        umap_metric: ``metric`` argument forwarded to ``umap.UMAP``.
        cat_counts: optional sequence holding one per-row count array per
            table, aligned with that table's rows.
        use_max_count: if true, the table has more than ``max_size`` rows and
            ``cat_counts`` is given, project only the rows whose count is
            greater than the smallest integer threshold ``>= 1`` that leaves
            at most ``max_size`` rows. Otherwise the first ``max_size`` rows
            are projected.
    """
    import os  # local import: the file's top-level imports are not visible here

    min_rows = 20  # tables / selections smaller than this are not projected

    for k, emb in enumerate(emb_l):
        E = emb.weight.detach().cpu().numpy()
        print("umap", E.shape)
        n_rows = E.shape[0]
        counts = None if cat_counts is None else np.asarray(cat_counts[k])

        # Cardinality falls back to the row count when no counts are given
        # (previously this indexed cat_counts unconditionally and crashed).
        cardinality = n_rows if counts is None else len(counts)
        _save_norm_histogram(
            E,
            "Categorical norms: " + str(k) + " cardinality " + str(cardinality),
            os.path.join(output_dir, "cat-norm-histogram-" + str(k) + ".png"))

        if n_rows < min_rows:
            print("Skipping small embedding")
            continue

        rows, min_cnt = _select_umap_rows(n_rows, max_size, counts, use_max_count)
        if min_cnt:
            print("max_count_len", len(rows), "mincount", min_cnt)
        if len(rows) < min_rows:
            # Many tied counts can push the threshold above (almost) every row.
            print("Skipping: too few rows selected for UMAP")
            continue

        reducer = umap.UMAP(random_state=42, metric=umap_metric)
        Y = reducer.fit_transform(E[rows])
        n_vis = Y.shape[0]

        _save_umap_scatter(
            Y, rows, counts,
            "UMAP: categorical var. " + str(k) + "  (" + str(n_vis) + " of "
            + str(n_rows) + ", min count " + str(min_cnt) + ")",
            os.path.join(output_dir, "cat-" + str(k) + "-" + str(n_vis)
                         + "-of-" + str(n_rows) + "-umap.png"))


def _save_norm_histogram(E, title, path, n_bins=50):
    """Save a log-x histogram of the L2 norms of the rows of ``E`` to ``path``."""
    norms = np.linalg.norm(E, ord=2, axis=1)
    positive = norms[norms > 0]

    fig = plt.figure(figsize=(8, 8))
    try:
        plt.title(title)
        if positive.size:
            lo, hi = positive.min(), positive.max()
            if lo == hi:
                # A single distinct norm: widen the range so edges increase.
                lo, hi = lo / 2, hi * 2
            # n_bins log-spaced bins; pin the outer edges to the exact data
            # range so round-off in logspace cannot drop the extreme values.
            edges = np.logspace(np.log10(lo), np.log10(hi), n_bins + 1)
            edges[0], edges[-1] = lo, hi
            plt.hist(norms, bins=edges)
            plt.xscale("log")
        else:
            # Nothing positive to place on a log axis (all-zero or empty table).
            plt.hist(norms, bins=n_bins)
        plt.savefig(path)
    finally:
        plt.close(fig)


def _select_umap_rows(n_rows, max_size, counts, use_max_count):
    """Return ``(row_indices, min_cnt)`` selecting the rows UMAP projects.

    ``min_cnt`` is 0 when no count-based filtering was applied.
    """
    n_vis = min(max_size, n_rows)
    if not use_max_count or counts is None or n_vis == n_rows:
        return np.arange(n_vis), 0

    row_counts = counts[:n_rows]
    # Smallest integer t >= 1 with (row_counts > t).sum() <= max_size. That
    # holds exactly when t is at least the (max_size + 1)-th largest count, so
    # it is read off a partial sort instead of searched for one step at a time.
    min_cnt = 1
    if row_counts.size > max_size:
        pos = row_counts.size - 1 - max_size
        kth_largest = np.partition(row_counts, pos)[pos]
        min_cnt = max(1, int(np.ceil(kth_largest)))
    return np.flatnonzero(row_counts > min_cnt), min_cnt


def _save_umap_scatter(Y, rows, counts, title, path):
    """Scatter the projection ``Y`` (row i of ``Y`` is embedding row ``rows[i]``)."""
    # Larger, outlined markers keep sparse plots readable.
    if Y.shape[0] < 2500:
        linewidth, size = 1, 5
    else:
        linewidth, size = 0, 1

    fig = plt.figure(figsize=(8, 8))
    try:
        if counts is None:
            plt.scatter(-Y[:, 0], -Y[:, 1], s=size, marker=".",
                        linewidth=linewidth)
        else:
            # Colour each point by the count of its *own* embedding row (rows
            # may be a filtered subset); points without a count entry are not
            # drawn. log1p on both sides keeps colours in [0, 1] and avoids a
            # zero denominator when the largest count is 0 or 1.
            has_count = rows < len(counts)
            denom = np.log1p(counts.max()) if len(counts) else 0.0
            colours = np.log1p(counts[rows[has_count]]) / (denom if denom > 0 else 1.0)
            plt.scatter(-Y[has_count, 0], -Y[has_count, 1], s=size,
                        marker=".", linewidth=linewidth, c=colours,
                        cmap="viridis")
            plt.colorbar()
        plt.title(title)
        plt.savefig(path)
    finally:
        plt.close(fig)