in modules/mlvis-common/src/utils/kmeans.js [111:148]
/**
 * Reassign data instances so that, where possible, every cluster holds at least
 * `minCount` instances.
 *
 * Starts from the nearest-centroid assignment. It then repeatedly picks the
 * smallest under-filled cluster and moves into it the not-yet-moved data
 * instance closest to that cluster's centroid. Each instance is moved at most
 * once.
 *
 * @param {tf.Tensor} distances - 2-D distance tensor. Axis 0 indexes clusters
 *   and axis 1 indexes data instances, i.e. shape = [nClusters, nInstances];
 *   `argMin` over axis 0 gives each instance's nearest centroid.
 * @param {number} [minCount=1] - minimum number of instances each cluster should hold.
 * @returns {tf.Tensor} cluster ID for each data instance, shape = [nInstances].
 *   If the constraint cannot be satisfied (e.g. nClusters * minCount > nInstances),
 *   the best-effort assignment reached after every instance has moved once is returned.
 */
export function fillEmptyClusters(distances, minCount = 1) {
  return tf.tidy(() => {
    const [nClusters, nInstances] = distances.shape;
    // Every iteration gives one more moved instance to a cluster that still has
    // fewer than `minCount` members, and moved instances never move again.
    // So at most nClusters * ceil(minCount) moves are ever needed, and we can
    // never move more instances than exist.
    const maxMoves = Math.min(nClusters * Math.ceil(minCount), nInstances);
    const minCountScalar = tf.scalar(minCount);
    const infinity = tf.scalar(Infinity);

    // mins: the ID of the centroid nearest to each data instance, shape = [nInstances]
    let mins = tf.argMin(distances, 0).toInt();
    // IDs of data instances already moved; int32 to match the `argMin` indices
    // that get concatenated onto it.
    let mutatedIds = tf.tensor1d([], 'int32');

    for (let count = 0; count < maxMoves; count++) {
      const clusterSampleCount = tf.sum(tf.oneHot(mins, nClusters), 0);
      const nUnderfilledClusters = tf.sum(
        tf.less(clusterSampleCount, minCountScalar),
        0
      );
      // nUnderfilledClusters is a scalar, but `.dataSync()` returns an array with length = 1
      if (nUnderfilledClusters.dataSync()[0] <= 0) {
        return mins;
      }
      // tackle clusters one by one, starting from the smallest cluster
      const clusterToFill = tf.argMin(clusterSampleCount, 0).toInt();
      // Only distances to the centroid of `clusterToFill` matter here.
      // Gathering with a scalar index already drops the cluster axis; reshape
      // (rather than squeeze) keeps the result 1-D even when nInstances === 1.
      const distancesToCentroid = tf.reshape(
        tf.gather(distances, clusterToFill),
        [-1]
      );
      // dataIds that have already been mutated before should not be mutated again:
      // set values at these indices to Infinity so they will not be picked by `argMin`
      const maskedDistances = assign(distancesToCentroid, mutatedIds, infinity);
      // assign cluster K to data point P if P has the shortest distance
      // (among previously un-mutated data) to cluster K's center
      const dataIdToMutate = tf.argMin(maskedDistances, 0).reshape([1]);
      mutatedIds = tf.concat([mutatedIds, dataIdToMutate]);
      mins = assign(mins, dataIdToMutate, clusterToFill.reshape([1]));
    }
    return mins;
  });
}