in torchbenchmark/models/dlrm/tools/visualize.py [0:0]
def _log_bin_edges(values, n_bins):
    """Return log-spaced histogram bin edges spanning ``values``.

    Returns ``None`` when ``values`` holds no strictly positive entry, because
    such data cannot be placed on a logarithmic axis.
    """
    edges = np.histogram(values, bins=n_bins)[1]
    low, high = edges[0], edges[-1]
    if low <= 0:
        # The lowest linear edge is non-positive when a value is zero, or when
        # all values are equal (numpy then pads the range by 0.5 on each
        # side). log10 of it would yield -inf/NaN edges, so start from the
        # smallest positive value instead.
        positive = values[values > 0]
        if positive.size == 0:
            return None
        low = positive.min()
    return np.logspace(np.log10(low), np.log10(high), len(edges))


def _rows_above_count_threshold(counts, n_rows, max_size):
    """Select rows by count so that at most ``max_size`` entries qualify.

    Starting at 1, the threshold is raised until no more than ``max_size``
    entries of ``counts`` exceed it.

    Returns:
        A ``(rows, min_cnt)`` tuple: the indices (below ``n_rows``) whose count
        is strictly greater than the threshold, and the threshold itself.
    """
    min_cnt = 1
    while (counts > min_cnt).sum() > max_size:
        min_cnt += 1
    rows = np.flatnonzero(counts[:n_rows] > min_cnt)
    return rows, min_cnt


def visualize_embeddings_umap(emb_l,
                              output_dir="",
                              max_size=500000,
                              umap_metric="euclidean",
                              cat_counts=None,
                              use_max_count=True):
    """Plot a norm histogram and a 2-D UMAP projection for each embedding table.

    For every table ``k`` two PNG files are written into ``output_dir``:
    ``cat-norm-histogram-<k>.png`` (histogram of the row L2 norms) and
    ``cat-<k>-<n_vis>-of-<rows>-umap.png`` (UMAP scatter plot). Tables with
    fewer than 20 rows only get the histogram.

    Args:
        emb_l: Sequence of objects exposing a ``weight`` tensor; each weight is
            read via ``.weight.detach().cpu().numpy()`` and treated as a 2-D
            array with one row per category.
        output_dir: Directory receiving the PNG files. An empty string means
            the current working directory.
        max_size: Upper bound on the number of rows passed to UMAP.
        umap_metric: Value forwarded as ``metric`` to ``umap.UMAP``.
        cat_counts: Optional per-table counts, indexed as ``cat_counts[k]``.
            When given, points are coloured by log-scaled count and the counts
            can drive row selection. When ``None``, points are drawn uncoloured
            and oversized tables are truncated to their first ``max_size``
            rows.
        use_max_count: When true and a table has more than ``max_size`` rows,
            keep the rows whose count exceeds a threshold chosen so that at
            most ``max_size`` qualify, instead of taking the first rows.

    Returns:
        None.
    """
    import os  # local import: only needed to build the output file paths

    n_bins = 50

    for k, emb in enumerate(emb_l):
        E = emb.weight.detach().cpu().numpy()
        n_rows = E.shape[0]
        print("umap", E.shape)

        counts = None if cat_counts is None else np.asarray(cat_counts[k])
        # Without counts, the number of embedding rows stands in for the
        # cardinality shown in the histogram title.
        cardinality = n_rows if counts is None else counts.shape[0]

        # Histogram of the per-row L2 norms on a logarithmic axis.
        norms = np.linalg.norm(E, ord=2, axis=1)
        log_edges = _log_bin_edges(norms, n_bins)

        plt.figure(figsize=(8, 8))
        plt.title("Categorical norms: " + str(k) + " cardinality " + str(cardinality))
        if log_edges is None:
            # No positive norm at all: a log axis is impossible, stay linear.
            plt.hist(norms, bins=n_bins)
        else:
            plt.hist(norms, bins=log_edges)
            plt.xscale("log")
        plt.savefig(os.path.join(output_dir, "cat-norm-histogram-" + str(k) + ".png"))
        plt.close()

        if n_rows < 20:
            print("Skipping small embedding")
            continue

        n_vis = min(max_size, n_rows)
        min_cnt = 0

        # ``rows`` records which embedding row each projected point came from,
        # so that the colours below can be matched to the right counts.
        if not use_max_count or counts is None or n_vis == n_rows:
            rows = np.arange(n_vis)
            points = E[:n_vis, :]
        else:
            rows, min_cnt = _rows_above_count_threshold(counts, n_rows, max_size)
            print("max_count_len", len(rows), "mincount", min_cnt)
            if rows.size == 0:
                # Nothing survived the threshold; UMAP cannot fit empty input.
                print("Skipping embedding: no category has count above", min_cnt)
                continue
            points = E[rows]
            n_vis = len(rows)

        reducer = umap.UMAP(random_state=42, metric=umap_metric)
        Y = reducer.fit_transform(points)

        plt.figure(figsize=(8, 8))

        # Draw larger, outlined markers when there are few points.
        if Y.shape[0] < 2500:
            linewidth, size = 1, 5
        else:
            linewidth, size = 0, 1

        if counts is None:
            plt.scatter(-Y[:, 0], -Y[:, 1], s=size, marker=".", linewidth=linewidth)
        else:
            # Only points that have a count entry can be coloured.
            shown = rows < counts.shape[0]
            peak = counts.max() if counts.size else 0
            # Normalise log(count + 1) by log(peak + 1) so values lie in [0, 1];
            # fall back to 1.0 when every count is zero to avoid dividing by 0.
            scale = float(np.log1p(peak)) if peak > 0 else 1.0
            colors = np.log1p(counts[rows[shown]]) / scale
            plt.scatter(-Y[shown, 0], -Y[shown, 1], s=size, marker=".",
                        linewidth=linewidth, c=colors, cmap="viridis")
            plt.colorbar()

        plt.title("UMAP: categorical var. " + str(k) + " (" + str(n_vis) + " of "
                  + str(n_rows) + ", min count " + str(min_cnt) + ")")
        plt.savefig(os.path.join(
            output_dir,
            "cat-" + str(k) + "-" + str(n_vis) + "-of-" + str(n_rows) + "-umap.png"))
        plt.close()