export function fillEmptyClusters()

in modules/mlvis-common/src/utils/kmeans.js [111:148]


/**
 * Reassign data points so that, where possible, no cluster holds fewer than
 * `minCount` points.
 *
 * Points start in the cluster of their nearest centroid. Each iteration then
 * picks the cluster with the fewest points and moves into it the data point
 * closest to that cluster's centroid. Points that were already moved are masked
 * out with Infinity so they are not picked again. Each iteration moves one
 * point, and there are at most `nClusters` iterations. So some clusters may
 * still hold fewer than `minCount` points when this returns (e.g. when
 * `minCount > 1`).
 *
 * @param {tf.Tensor} distances - centroid-to-point distances; axis 0 indexes
 *   centroids (its length is taken as the number of clusters) and axis 1 indexes
 *   data points (presumably shape [nClusters, nInstances] — TODO confirm callers).
 * @param {number} [minCount=1] - a cluster with fewer points than this is
 *   treated as empty and gets filled.
 * @returns {tf.Tensor} the cluster ID assigned to each data point.
 */
export function fillEmptyClusters(distances, minCount = 1) {
  return tf.tidy(() => {
    const nClusters = distances.shape[0];
    // Loop-invariant threshold, created once instead of once per iteration.
    const minCountScalar = tf.scalar(minCount);
    // mins: the ID of the centroid nearest to each data instance, shape = [nInstances]
    let mins = tf.argMin(distances, 0).toInt();
    // IDs of data points that have already been reassigned. int32 so that it
    // matches the dtype of the `argMin` indices concatenated onto it below.
    let mutatedIds = tf.tensor1d([], 'int32');

    // Each iteration reassigns exactly one data point; `nClusters` bounds the
    // total number of reassignments.
    for (let count = 0; count < nClusters; count++) {
      // Each iteration runs in its own tidy scope, so its intermediates (notably
      // the one-hot encoding of `mins`) are released immediately instead of
      // piling up until the outer tidy finishes. Only the updated `mins` and
      // `mutatedIds` are returned out of the scope. A null return means no
      // cluster is below `minCount` any more.
      const next = tf.tidy(() => {
        const clusterSampleCount = tf.sum(tf.oneHot(mins, nClusters), 0);
        const nEmptyClusters = tf.sum(
          tf.less(clusterSampleCount, minCountScalar),
          0
        );
        // nEmptyClusters is a scalar, but `.dataSync()` returns an array with length = 1
        if (nEmptyClusters.dataSync()[0] <= 0) {
          return null;
        }
        // tackle clusters one by one, starting from the smallest cluster
        const clusterToFill = tf.argMin(clusterSampleCount, 0).toInt();
        // Only the distances to the centroid of `clusterToFill` matter.
        // Flatten with reshape rather than squeeze so that a single data point
        // still yields a rank-1 tensor for `argMin(..., 0)`.
        const distancesToCluster = tf
          .gather(distances, clusterToFill)
          .reshape([-1]);
        // dataIds that have already been mutated should not be mutated again:
        // set values at these indices to Infinity so `argMin` will not pick them
        const maskedDistances = assign(
          distancesToCluster,
          mutatedIds,
          tf.scalar(Infinity)
        );
        // assign cluster K to data point P if P has the shortest distance
        // (among previously un-mutated data) to cluster K's center
        const dataIdToMutate = tf.argMin(maskedDistances, 0).reshape([1]);
        return {
          mins: assign(mins, dataIdToMutate, clusterToFill.reshape([1])),
          mutatedIds: tf.concat([mutatedIds, dataIdToMutate]),
        };
      });

      if (next === null) {
        return mins;
      }
      ({ mins, mutatedIds } = next);
    }
    return mins;
  });
}