in analogy_generation.py [0:0]
def cluster_feats(filehandle, base_classes, cachefile, n_clusters=100):
    """Return per-class k-means centroids of features, caching via pickle.

    Args:
        filehandle: mapping exposing 'all_labels', 'all_feats' and 'count'
            entries (HDF5-style file handle, presumably — confirm with caller).
        base_classes: iterable of class labels to cluster.
        cachefile: path to the pickle cache; loaded if it exists, written
            otherwise.
        n_clusters: number of centroids computed per class.

    Returns:
        A list with one centroid result per entry of ``base_classes``
        (or whatever list was previously pickled to ``cachefile``).
    """
    # Fast path: a previous run already pickled the centroids.
    if os.path.isfile(cachefile):
        with open(cachefile, 'rb') as cache:
            return pickle.load(cache)

    labels = filehandle['all_labels'][...]
    feats = filehandle['all_feats']
    valid_count = filehandle['count'][0]

    centroids = []
    for pos, cls in enumerate(base_classes):
        print('Clustering class {:d}:{:d}'.format(pos, cls))
        # Rows belonging to this class, restricted to the filled portion
        # of the (possibly over-allocated) feature array.
        rows = np.where(labels == cls)[0]
        rows = rows[rows < valid_count]
        # use a reimplementation of torch kmeans for reproducible results
        # TODO: Figure out why this is important
        centroids.append(torch_kmeans.kmeans(feats[rows, :], n_clusters, 20))

    with open(cachefile, 'wb') as cache:
        pickle.dump(centroids, cache)
    return centroids