in moonlight/glyphs/saved_classifier.py [0:0]
def __init__(self,
             structure,
             saved_model_dir,
             num_sections=19,
             *args,
             **kwargs):
  """Loads a saved classifier model for the OMR engine.

  Extracts 1D patches along each staff, rewires the saved model's 'input'
  tensor to those patches, and stores the per-column predicted class ids in
  `self.output`.

  Args:
    structure: A `structure.Structure`. Its `staff_remover.remove_staves` and
      `staff_detector` are used to extract the staves.
    saved_model_dir: Path to the TF saved_model directory to load.
    num_sections: Number of vertical positions of patches to extract, centered
      on the middle staff line.
    *args: Passed through to the superclass constructor.
    **kwargs: Passed through to the superclass constructor. If it contains
      `run_min_length`, that value takes precedence over any run_min_length
      constant stored in the saved model.

  Raises:
    ValueError: If there is no default `tf.Session`; if none of the
      `_SIGNATURE_KEYS` signatures is present in the saved model; if the saved
      model input could not be interpreted as a 3D array with the patch size;
      or if the saved model's run_min_length constant is not a known scalar.
  """
  super(SavedConvolutional1DClassifier, self).__init__(*args, **kwargs)

  sess = tf.get_default_session()
  if sess is None:
    # Fail early with a clear message instead of deep inside the loader.
    raise ValueError('SavedConvolutional1DClassifier must be constructed '
                     'within a default tf.Session.')
  meta_graph_def = tf.saved_model.loader.load(
      sess, [tf.saved_model.tag_constants.SERVING], saved_model_dir)

  # Use the first of _SIGNATURE_KEYS that the saved model defines.
  signature = next((meta_graph_def.signature_def[key]
                    for key in _SIGNATURE_KEYS
                    if key in meta_graph_def.signature_def), None)
  if signature is None:
    # Wrap the keys in a 1-tuple so '%' formats them as a single value even
    # when _SIGNATURE_KEYS is itself a tuple.
    raise ValueError('One of the following signatures must be present: %s' %
                     (_SIGNATURE_KEYS,))

  input_info = signature.inputs['input']
  dims = input_info.tensor_shape.dim
  # The input is fed [num_patches, patch_height, patch_width] below, so the
  # last two dimensions must be statically known and positive.
  if not (len(dims) == 3 and dims[1].size > 0 and dims[2].size > 0):
    raise ValueError('Invalid patches input: ' + str(input_info))
  patch_height = dims[1].size
  patch_width = dims[2].size

  with tf.name_scope('saved_classifier'):
    self.staffline_extractor = staffline_extractor.StafflineExtractor(
        structure.staff_remover.remove_staves,
        structure.staff_detector,
        num_sections=num_sections,
        target_height=patch_height)
    stafflines = self.staffline_extractor.extract_staves()
    stafflines_shape = tf.shape(stafflines)
    num_staves_t = stafflines_shape[0]
    # Dynamic section count, kept separate from the `num_sections` argument.
    num_sections_t = stafflines_shape[1]
    staffline_patches = patches.patches_1d(stafflines, patch_width)
    staffline_patches_shape = tf.shape(staffline_patches)
    patches_per_position = staffline_patches_shape[2]
    flat_patches = tf.reshape(staffline_patches, [
        num_staves_t * num_sections_t * patches_per_position, patch_height,
        patch_width
    ])

    # Feed in the flat extracted patches as the classifier input, by replacing
    # the saved model's input tensor with `flat_patches`.
    predictions_name = signature.outputs[
        prediction_keys.PredictionKeys.CLASS_IDS].name
    predictions = contrib_graph_editor.graph_replace(
        sess.graph.get_tensor_by_name(predictions_name),
        {sess.graph.get_tensor_by_name(input_info.name): flat_patches})
    # Reshape to the original patches shape.
    predictions = tf.reshape(predictions, staffline_patches_shape[:3])

    # Pad the output. We take only the valid patches, but we want to shift all
    # of the predictions so that a patch at index i on the x-axis is centered
    # on column i. This determines the x coordinates of the glyphs.
    width = stafflines_shape[-1]
    predictions_width = tf.shape(predictions)[-1]
    pad_before = (width - predictions_width) // 2
    pad_after = width - predictions_width - pad_before
    pad_shape_before = tf.concat([staffline_patches_shape[:2], [pad_before]],
                                 axis=0)
    pad_shape_after = tf.concat([staffline_patches_shape[:2], [pad_after]],
                                axis=0)
    self.output = tf.concat(
        [
            # NONE has value 1.
            tf.ones(pad_shape_before, tf.int64),
            tf.cast(predictions, tf.int64),
            tf.ones(pad_shape_after, tf.int64),
        ],
        axis=-1)

  # run_min_length can be set on the saved model to tweak its behavior, but
  # should be overridden by the keyword argument.
  if 'run_min_length' not in kwargs:
    try:
      # Try to read the run min length from the saved model. This is tweaked
      # on a per-model basis.
      run_min_length_t = sess.graph.get_tensor_by_name(
          _RUN_MIN_LENGTH_CONSTANT_NAME)
    except KeyError:
      pass  # No run_min_length tensor in the saved model.
    else:
      run_min_length = contrib_util.constant_value(run_min_length_t)
      # Implicit comparison is invalid on a NumPy array.
      # pylint: disable=g-explicit-bool-comparison
      if run_min_length is None or run_min_length.shape != ():
        raise ValueError('Bad run_min_length: {}'.format(run_min_length))
      # Overwrite the property after the Convolutional1DGlyphClassifier
      # constructor completes.
      self.run_min_length = int(run_min_length)