def main()

in moonlight/training/clustering/staffline_patches_kmeans_pipeline.py [0:0]


def main(_):
  """Extracts staffline patches, clusters them, and writes the centroids.

  Builds a pipeline that reads the files matching `FLAGS.music_pattern`,
  extracts patches with `StafflinePatchesDoFn`, and writes them as TFRecords of
  `tf.train.Example` protos into a temporary directory. It then runs
  `train_kmeans` on those records and writes one `tf.train.Example` per
  resulting cluster to `FLAGS.output_path`.

  Args:
    _: Unused positional argument.

  Raises:
    ValueError: If no files match `FLAGS.music_pattern`.
  """
  tf.logging.info('Building the pipeline...')
  records_dir = tempfile.mkdtemp(prefix='staffline_kmeans')
  try:
    patch_file_prefix = os.path.join(records_dir, 'patches')
    with pipeline_flags.create_pipeline() as pipeline:
      filenames = file_io.get_matching_files(FLAGS.music_pattern)
      # Raise explicitly instead of asserting: asserts are stripped when Python
      # runs with -O, which would let an empty input slip through.
      if not filenames:
        raise ValueError('Must have matched some filenames')
      # Only subsample pages when a positive limit smaller than the number of
      # matched files was requested.
      if 0 < FLAGS.num_pages < len(filenames):
        filenames = random.sample(filenames, FLAGS.num_pages)
      filenames = pipeline | beam.transforms.Create(filenames)
      patches = filenames | beam.ParDo(
          staffline_patches_dofn.StafflinePatchesDoFn(
              patch_height=FLAGS.patch_height,
              patch_width=FLAGS.patch_width,
              num_stafflines=FLAGS.num_stafflines,
              timeout_ms=FLAGS.timeout_ms,
              max_patches_per_page=FLAGS.max_patches_per_page))
      if FLAGS.num_outputs:
        # Sample.FixedSizeGlobally emits a single element holding the list of
        # sampled values, so flatten it back into individual examples before
        # they reach the per-example proto coder below.
        patches = (
            patches
            | combiners.Sample.FixedSizeGlobally(FLAGS.num_outputs)
            | beam.FlatMap(lambda sampled_patches: sampled_patches))
      patches |= beam.io.WriteToTFRecord(
          patch_file_prefix, beam.coders.ProtoCoder(tf.train.Example))
      # NOTE(review): presumably the pipeline executes when this `with` block
      # exits, which is why the message is logged last -- confirm against
      # pipeline_flags.create_pipeline.
      tf.logging.info('Running the pipeline...')
    tf.logging.info('Running k-means...')
    patch_files = file_io.get_matching_files(patch_file_prefix + '*')
    clusters = train_kmeans(patch_files, FLAGS.kmeans_num_clusters,
                            FLAGS.kmeans_batch_size, FLAGS.kmeans_num_steps)
    tf.logging.info('Writing the centroids...')
    with tf_record.TFRecordWriter(FLAGS.output_path) as writer:
      for cluster in clusters:
        # Each centroid is stored with the patch dimensions from the flags.
        example = tf.train.Example()
        example.features.feature['features'].float_list.value.extend(cluster)
        example.features.feature['height'].int64_list.value.append(
            FLAGS.patch_height)
        example.features.feature['width'].int64_list.value.append(
            FLAGS.patch_width)
        writer.write(example.SerializeToString())
    tf.logging.info('Done!')
  finally:
    # Always remove the temporary patch records, even if a step above failed.
    shutil.rmtree(records_dir)