def __init__()

in moonlight/glyphs/saved_classifier.py


  def __init__(self,
               structure,
               saved_model_dir,
               num_sections=19,
               *args,
               **kwargs):
    """Loads a saved classifier model for the OMR engine.

    Args:
      structure: A `structure.Structure`.
      saved_model_dir: Path to the TF saved_model directory to load.
      num_sections: Number of vertical positions of patches to extract, centered
        on the middle staff line.
      *args: Passed through to `SavedConvolutional1DClassifier`.
      **kwargs: Passed through to `SavedConvolutional1DClassifier`.

    Raises:
      ValueError: If the saved model input signature could not be interpreted
        as a 3D patch tensor with a static height and width.
    """
    super(SavedConvolutional1DClassifier, self).__init__(*args, **kwargs)
    sess = tf.get_default_session()
    graph_def = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], saved_model_dir)

    signature = None
    for key in _SIGNATURE_KEYS:
      if key in graph_def.signature_def:
        signature = graph_def.signature_def[key]
        break
    else:
      # for/else is only executed if the loop completes without breaking.
      raise ValueError('One of the following signatures must be present: %s' %
                       _SIGNATURE_KEYS)

    input_info = signature.inputs['input']
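    # The serving signature's 'input' must describe a 3-D patch tensor
    # [batch, patch_height, patch_width] with a static height and width.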
    if not (len(input_info.tensor_shape.dim) == 3 and
            input_info.tensor_shape.dim[1].size > 0 and
            input_info.tensor_shape.dim[2].size > 0):
      raise ValueError('Invalid patches input: ' + str(input_info))
    patch_height = input_info.tensor_shape.dim[1].size
    patch_width = input_info.tensor_shape.dim[2].size

    with tf.name_scope('saved_classifier'):
      self.staffline_extractor = staffline_extractor.StafflineExtractor(
          structure.staff_remover.remove_staves,
          structure.staff_detector,
          num_sections=num_sections,
          target_height=patch_height)
      stafflines = self.staffline_extractor.extract_staves()
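      # stafflines is assumed to be laid out as
      # [num_staves, num_sections, patch_height, staffline_width],
      # consistent with the indexing and reshape below.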
      num_staves = tf.shape(stafflines)[0]
      num_sections = tf.shape(stafflines)[1]
      staffline_patches = patches.patches_1d(stafflines, patch_width)
      staffline_patches_shape = tf.shape(staffline_patches)
      patches_per_position = staffline_patches_shape[2]
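      # staffline_patches: [num_staves, num_sections, patches_per_position,
      # patch_height, patch_width]. Flatten the leading dimensions so that
      # every patch becomes one example in the classifier batch.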
      flat_patches = tf.reshape(staffline_patches, [
          num_staves * num_sections * patches_per_position, patch_height,
          patch_width
      ])

      # Feed in the flat extracted patches as the classifier input.
      predictions_name = signature.outputs[
          prediction_keys.PredictionKeys.CLASS_IDS].name
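      # graph_replace clones the subgraph that computes the predictions
      # tensor, substituting flat_patches for the saved model's input
      # placeholder, so classification happens in-graph with no feed_dict.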
      predictions = contrib_graph_editor.graph_replace(
          sess.graph.get_tensor_by_name(predictions_name), {
              sess.graph.get_tensor_by_name(signature.inputs['input'].name):
                  flat_patches
          })
      # Reshape to the original patches shape.
      predictions = tf.reshape(predictions, staffline_patches_shape[:3])

      # Pad the output. Predictions only exist for patches that lie fully
      # within the staffline, so shift them so that the prediction at x index
      # i corresponds to the patch centered on column i. This determines the x
      # coordinates of the glyphs.
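      # For example, assuming patches_1d extracts one patch per valid
      # position, predictions_width == width - patch_width + 1, so about
      # (patch_width - 1) // 2 NONE columns are prepended and the remainder
      # appended.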
      width = tf.shape(stafflines)[-1]
      predictions_width = tf.shape(predictions)[-1]
      pad_before = (width - predictions_width) // 2
      pad_shape_before = tf.concat([staffline_patches_shape[:2], [pad_before]],
                                   axis=0)
      pad_shape_after = tf.concat([
          staffline_patches_shape[:2], [width - predictions_width - pad_before]
      ],
                                  axis=0)
      self.output = tf.concat(
          [
              # NONE has value 1.
              tf.ones(pad_shape_before, tf.int64),
              tf.to_int64(predictions),
              tf.ones(pad_shape_after, tf.int64),
          ],
          axis=-1)
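      # self.output: [num_staves, num_sections, staffline_width], one glyph
      # class ID per staffline column.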

    # run_min_length can be stored in the saved model to tweak its behavior,
    # but an explicit keyword argument takes precedence.
    if 'run_min_length' not in kwargs:
      try:
        # Try to read the run min length from the saved model. This is tweaked
        # on a per-model basis.
        run_min_length_t = sess.graph.get_tensor_by_name(
            _RUN_MIN_LENGTH_CONSTANT_NAME)
        run_min_length = contrib_util.constant_value(run_min_length_t)
        # Implicit comparison is invalid on a NumPy array.
        # pylint: disable=g-explicit-bool-comparison
        if run_min_length is None or run_min_length.shape != ():
          raise ValueError('Bad run_min_length: {}'.format(run_min_length))
        # Overwrite the property after the Convolutional1DGlyphClassifier
        # constructor completes.
        self.run_min_length = int(run_min_length)
      except KeyError:
        pass  # No run_min_length tensor in the saved model.
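
A minimal usage sketch (the `structure` object and the saved-model path are
placeholders; the real OMR pipeline wires these up and feeds the page image
elsewhere):

  with tf.Session() as sess:
    classifier = SavedConvolutional1DClassifier(
        structure, '/path/to/saved_model_dir', num_sections=19)
    # classifier.output holds glyph class IDs with shape
    # [num_staves, num_sections, staffline_width].
    glyph_ids = sess.run(classifier.output)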