in moonlight/glyphs/knn.py [0:0]
def __init__(self, corpus_file, staffline_extractor, **kwargs):
  """Build a k-nearest-neighbor glyph classifier with labeled patches.

  Every horizontal patch of the extracted stafflines is compared against the
  corpus patches by squared Euclidean distance. The `K_NEAREST_VALUE` closest
  corpus patches vote on the glyph type of the patch. The per-patch
  predictions are stored in `self.output`, padded with NONE on both sides so
  that the last dimension matches the width of the extracted stafflines.

  Args:
    corpus_file: Path to the TFRecords of Examples with patch (cluster)
      values in the "patch" feature, and the glyph label in the "label"
      feature.
    staffline_extractor: The staffline extractor. Its `extract_staves()`
      result supplies the stafflines that are classified.
    **kwargs: Passed through to `Convolutional1DGlyphClassifier`.
  """
  super(NearestNeighborGlyphClassifier, self).__init__(**kwargs)

  # Derive everything that depends on the NONE glyph from the enum itself
  # instead of repeating its numeric value as a literal.
  none_glyph = musicscore_pb2.Glyph.NONE
  first_real_glyph = none_glyph + 1
  none_value = tf.to_int64(none_glyph)

  patch_height, patch_width = corpus.get_patch_shape(corpus_file)
  centroids, labels = corpus.parse_corpus(corpus_file, patch_height,
                                          patch_width)
  # Flatten each centroid patch into a single row vector.
  centroids_shape = tf.shape(centroids)
  flattened_centroids = tf.reshape(
      centroids,
      [centroids_shape[0], centroids_shape[1] * centroids_shape[2]])

  self.staffline_extractor = staffline_extractor
  stafflines = staffline_extractor.extract_staves()
  # Remember the full staffline width; the predictions are padded back to
  # this width at the end.
  width = tf.shape(stafflines)[-1]
  # Shape (num_staves, num_stafflines, num_patches, height, patch_width).
  staffline_patches = patches.patches_1d(stafflines, patch_width)
  staffline_patches_shape = tf.shape(staffline_patches)
  # One row per patch, one column per pixel, to match the centroid layout.
  flattened_patches = tf.reshape(staffline_patches, [
      staffline_patches_shape[0] * staffline_patches_shape[1] *
      staffline_patches_shape[2],
      staffline_patches_shape[3] * staffline_patches_shape[4]
  ])
  distance_matrix = _squared_euclidean_distance_matrix(
      flattened_patches, flattened_centroids)

  # Take the k centroids with the lowest distance to each patch. Wrap the k
  # constant in a tf.identity, which tests can use to feed in another value.
  k_value = tf.identity(tf.constant(K_NEAREST_VALUE), name='k_nearest_value')
  # top_k selects the largest values, so negate the distances.
  nearest_centroid_inds = tf.nn.top_k(-distance_matrix, k=k_value)[1]
  # Get the label corresponding to each nearby centroid, and reshape the
  # labels back to the original shape.
  nearest_labels = tf.reshape(
      tf.gather(labels, tf.reshape(nearest_centroid_inds, [-1])),
      tf.shape(nearest_centroid_inds))
  # Make a histogram of counts for each glyph type in the nearest centroids,
  # for each row (patch).
  bins = tf.map_fn(lambda row: tf.bincount(row, minlength=NUM_GLYPHS),
                   tf.to_int32(nearest_labels))
  # Take the argmax of the histogram over the glyph types after NONE to get
  # the top prediction; NONE is handled separately below. The argmax index is
  # relative to the slice, so shift it by the slice offset to recover the
  # glyph type.
  mode_out_of_k = (
      tf.argmax(bins[:, first_real_glyph:], axis=1) + first_real_glyph)
  # Force predictions to NONE only if all k nearby centroids were NONE.
  # Otherwise, the non-NONE nearby centroids will contribute to the
  # prediction.
  mode_out_of_k = tf.where(
      tf.equal(bins[:, none_glyph], k_value),
      tf.fill(tf.shape(mode_out_of_k), none_value),
      mode_out_of_k)
  predictions = tf.reshape(mode_out_of_k, staffline_patches_shape[:3])

  # Pad the output with NONE on both sides of the last dimension. When the
  # missing width is odd, the extra column goes after the predictions.
  predictions_width = tf.shape(predictions)[-1]
  pad_before = (width - predictions_width) // 2
  pad_after = width - predictions_width - pad_before
  pad_shape_before = tf.concat([staffline_patches_shape[:2], [pad_before]],
                               axis=0)
  pad_shape_after = tf.concat([staffline_patches_shape[:2], [pad_after]],
                              axis=0)
  self.output = tf.concat(
      [
          tf.fill(pad_shape_before, none_value),
          predictions,
          tf.fill(pad_shape_after, none_value),
      ],
      axis=-1)