# In moonlight/training/clustering/staffline_patches_kmeans_pipeline.py:
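# Imports assumed from the top of the module (a best-effort sketch; the exact
# moonlight-internal module paths for pipeline_flags, staffline_patches_dofn,
# and train_kmeans are not shown in this excerpt):
import os
import random
import shutil
import tempfile

import apache_beam as beam
from apache_beam import combiners
import tensorflow as tf
from tensorflow.python.lib.io import file_io
from tensorflow.python.lib.io import tf_record

# Flag definitions are elided in this excerpt; FLAGS is assumed to come from
# the standard TF/absl flags module.
FLAGS = tf.app.flags.FLAGS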
def main(_):
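  """Extracts staffline patches, clusters them, and writes the centroids.

  Runs a Beam pipeline that extracts patches from the matched music pages into
  temporary TFRecords, trains k-means on the patches, and writes one
  tf.train.Example per centroid to FLAGS.output_path.
  """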
  tf.logging.info('Building the pipeline...')
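  # Stage intermediate patch TFRecords in a temporary directory; the finally
  # block below removes it.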
  records_dir = tempfile.mkdtemp(prefix='staffline_kmeans')
  try:
    patch_file_prefix = os.path.join(records_dir, 'patches')
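    # pipeline_flags.create_pipeline() builds a Beam pipeline from
    # command-line flags; as a context manager, the pipeline runs when the
    # `with` block exits.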
    with pipeline_flags.create_pipeline() as pipeline:
      filenames = file_io.get_matching_files(FLAGS.music_pattern)
      assert filenames, 'Must have matched some filenames'
      if 0 < FLAGS.num_pages < len(filenames):
        filenames = random.sample(filenames, FLAGS.num_pages)
      filenames = pipeline | beam.transforms.Create(filenames)
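      # StafflinePatchesDoFn (defined in this package) emits a tf.train.Example
      # per extracted staffline patch for each page image.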
      patches = filenames | beam.ParDo(
          staffline_patches_dofn.StafflinePatchesDoFn(
              patch_height=FLAGS.patch_height,
              patch_width=FLAGS.patch_width,
              num_stafflines=FLAGS.num_stafflines,
              timeout_ms=FLAGS.timeout_ms,
              max_patches_per_page=FLAGS.max_patches_per_page))
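      # If requested, keep a fixed-size uniform sample of all patches.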
      if FLAGS.num_outputs:
        patches |= combiners.Sample.FixedSizeGlobally(FLAGS.num_outputs)
        # FixedSizeGlobally yields a single element holding the sampled list;
        # flatten it back into individual patches before writing.
        patches |= beam.FlatMap(lambda sampled: sampled)
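      # Write the patch Examples to sharded TFRecords under the temp prefix.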
      patches |= beam.io.WriteToTFRecord(
          patch_file_prefix, beam.coders.ProtoCoder(tf.train.Example))
      tf.logging.info('Running the pipeline...')
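    # Exiting the `with` block above blocks until the pipeline has run, so the
    # TFRecords are complete here.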
    tf.logging.info('Running k-means...')
    patch_files = file_io.get_matching_files(patch_file_prefix + '*')
    clusters = train_kmeans(patch_files, FLAGS.kmeans_num_clusters,
                            FLAGS.kmeans_batch_size, FLAGS.kmeans_num_steps)
    tf.logging.info('Writing the centroids...')
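    # Each centroid is written as a tf.train.Example carrying its flattened
    # pixel values plus the patch dimensions, so consumers can reshape the
    # features back into a patch.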
    with tf_record.TFRecordWriter(FLAGS.output_path) as writer:
      for cluster in clusters:
        example = tf.train.Example()
        example.features.feature['features'].float_list.value.extend(cluster)
        example.features.feature['height'].int64_list.value.append(
            FLAGS.patch_height)
        example.features.feature['width'].int64_list.value.append(
            FLAGS.patch_width)
        writer.write(example.SerializeToString())
    tf.logging.info('Done!')
  finally:
    shutil.rmtree(records_dir)
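

# The module presumably ends with the standard TF1-style entry point (not
# shown in this excerpt); a minimal sketch:
if __name__ == '__main__':
  tf.app.run(main)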