in moonlight/glyphs/knn_model.py [0:0]
def knn_kmeans_model(centroids, labels, patches=None):
"""The KNN k-means classifier model.
Args:
centroids: The k-means centroids NumPy array. Shape `(num_centroids,
patch_height, patch_width)`.
labels: The centroid labels NumPy array. Vector with length `num_centroids`.
patches: Optional input tensor for the patches. If None, a placeholder will
be used.
Returns:
The predictions (class ids) tensor determined from the input patches. Vector
with the same length as `patches`.
"""
with tf.name_scope('knn_model'):
centroids = tf.identity(
_to_float(tf.constant(_to_uint8(centroids))), name='centroids')
labels = tf.constant(labels, name='labels')
centroids_shape = tf.shape(centroids)
num_centroids = centroids_shape[0]
patch_height = centroids_shape[1]
patch_width = centroids_shape[2]
flattened_centroids = tf.reshape(
centroids, [num_centroids, patch_height * patch_width],
name='flattened_centroids')
if patches is None:
patches = tf.placeholder(
tf.float32, (None, centroids.shape[1], centroids.shape[2]),
name='patches')
patches_shape = tf.shape(patches)
flattened_patches = tf.reshape(
patches, [patches_shape[0], patches_shape[1] * patches_shape[2]],
name='flattened_patches')
with tf.name_scope('distance_matrix'):
distance_matrix = _squared_euclidean_distance_matrix(
flattened_patches, flattened_centroids)
# Take the k centroids with the lowest distance to each patch. Wrap the k
# constant in a tf.identity, which tests can use to feed in another value.
k_value = tf.identity(tf.constant(K_NEAREST_VALUE), name='k_nearest_value')
nearest_centroid_inds = tf.nn.top_k(-distance_matrix, k=k_value)[1]
# Get the label corresponding to each nearby centroids, and reshape the
# labels back to the original shape.
nearest_labels = tf.reshape(
tf.gather(labels, tf.reshape(nearest_centroid_inds, [-1])),
tf.shape(nearest_centroid_inds),
name='nearest_labels')
# Make a histogram of counts for each glyph type in the nearest centroids,
# for each row (patch).
length = NUM_GLYPHS
bins = tf.map_fn(
lambda row: tf.bincount(row, minlength=length, maxlength=length),
tf.to_int32(nearest_labels),
name='bins')
with tf.name_scope('mode_out_of_k'):
# Take the argmax of the histogram to get the top prediction. Discard
# glyph type 1 (NONE) for now.
mode_out_of_k = tf.argmax(
bins[:, musicscore_pb2.Glyph.NONE + 1:], axis=1) + 2
# Force predictions to NONE only if all k nearby centroids were NONE.
# Otherwise, the non-NONE nearby centroids will contribute to the
# prediction.
mode_out_of_k = tf.where(
tf.equal(bins[:, musicscore_pb2.Glyph.NONE], k_value),
tf.fill(
tf.shape(mode_out_of_k), tf.to_int64(musicscore_pb2.Glyph.NONE)),
mode_out_of_k)
return tf.identity(mode_out_of_k, name='predictions')