def cluster_feats()

in analogy_generation.py


import os
import pickle

import numpy as np

import torch_kmeans


def cluster_feats(filehandle, base_classes, cachefile, n_clusters=100):
    """Cluster the features of each base class and cache the resulting centroids."""
    if os.path.isfile(cachefile):
        # Reuse previously computed centroids if the cache file exists.
        with open(cachefile, 'rb') as f:
            centroids = pickle.load(f)
    else:
        centroids = []
        all_labels = filehandle['all_labels'][...]
        all_feats = filehandle['all_feats']

        # Only the first `count` rows of the feature/label arrays are valid.
        count = filehandle['count'][0]
        for j, i in enumerate(base_classes):
            print('Clustering class {:d}:{:d}'.format(j, i))
            # Select the valid feature rows belonging to base class i.
            idx = np.where(all_labels == i)[0]
            idx = idx[idx < count]
            X = all_feats[idx, :]
            # Use a reimplementation of torch k-means for reproducible results.
            # TODO: Figure out why this is important
            centroids_this = torch_kmeans.kmeans(X, n_clusters, 20)
            centroids.append(centroids_this)
        # Cache the per-class centroids so later runs skip the clustering step.
        with open(cachefile, 'wb') as f:
            pickle.dump(centroids, f)
    return centroids
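
A minimal usage sketch, assuming the feature store is an HDF5 file opened with h5py (the dataset names 'all_feats', 'all_labels', and 'count' follow from the indexing in the function above; the file name, cache path, and base-class ids below are placeholders):

import h5py

# Hypothetical paths and class ids, for illustration only.
with h5py.File('base_class_feats.hdf5', 'r') as feat_file:
    base_classes = [0, 1, 2]
    centroids = cluster_feats(feat_file, base_classes,
                              'centroids_cache.pkl', n_clusters=100)

# centroids is a list parallel to base_classes: centroids[j] holds the
# cluster centers computed from the features of class base_classes[j].

Because the result is pickled to cachefile, repeated runs reuse the stored centroids; delete the cache file to force re-clustering.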