in moonlight/training/clustering/staffline_patches_kmeans_pipeline.py [0:0]
def train_kmeans(patch_file_pattern,
                 num_clusters,
                 batch_size,
                 train_steps,
                 min_eval_frequency=None):
    """Runs TensorFlow K-Means over TFRecords.

    Trains a `KMeansClustering` estimator inside a temporary model directory,
    reads the learned cluster centers back out of the final checkpoint, and
    removes the temporary directory before returning (also when training or
    restoring raises).

    Args:
      patch_file_pattern: Pattern that matches TFRecord file(s) holding
        Examples with image patches.
      num_clusters: Number of output clusters.
      batch_size: Size of a k-means minibatch.
      train_steps: Number of steps for k-means training. Also used to locate
        the checkpoint to restore, so it must be an integer.
      min_eval_frequency: The minimum number of steps between evaluations. Of
        course, evaluation does not occur if no new snapshot is available,
        hence, this is the minimum. If 0, the evaluation will only happen
        after training. If None, defaults to 1. To avoid checking for new
        checkpoints too frequently, the interval is further limited to be at
        least check_interval_secs between checks. See
        third_party/tensorflow/contrib/learn/python/learn/experiment.py for
        details.

    Returns:
      A NumPy array of shape (num_clusters, patch_height * patch_width). The
      cluster centers.
    """
    # Flattened length of a single patch; shared by the Example parser and by
    # the placeholder variable used to read the centers back.
    num_features = FLAGS.patch_height * FLAGS.patch_width

    def input_fn():
        """The tf.learn input_fn.

        Returns:
          features, a float32 tensor of shape
              (batch_size, patch_height * patch_width).
          None for labels (not applicable to k-means).
        """
        examples = contrib_learn.read_batch_examples(
            patch_file_pattern,
            batch_size,
            tf.TFRecordReader,
            queue_capacity=batch_size * 2)
        features = tf.parse_example(
            examples,
            {'features': tf.FixedLenFeature(num_features, tf.float32)},
        )['features']
        return features, None  # no labels

    def experiment_fn(run_config, unused_hparams):
        """The tf.learn experiment_fn.

        Args:
          run_config: The run config to be passed to the KMeansClustering.
          unused_hparams: Hyperparameters; not applicable.

        Returns:
          A tf.contrib.learn.Experiment.
        """
        kmeans = contrib_learn.KMeansClustering(
            num_clusters=num_clusters, config=run_config)
        return contrib_learn.Experiment(
            estimator=kmeans,
            train_steps=train_steps,
            train_input_fn=input_fn,
            eval_steps=1,
            eval_input_fn=input_fn,
            min_eval_frequency=min_eval_frequency)

    output_dir = tempfile.mkdtemp(prefix='staffline_patches_kmeans')
    try:
        learn_runner.run(
            experiment_fn,
            run_config=contrib_learn.RunConfig(model_dir=output_dir))

        # Restore inside a private graph. The Saver looks the variable up in
        # the checkpoint by its op name, which must be exactly 'clusters'. In
        # the shared default graph a repeated call (or any pre-existing op
        # with that name) would get a uniquified name such as 'clusters_1',
        # and the restore would fail. A private graph also keeps this helper
        # from leaking variables into the caller's default graph.
        with tf.Graph().as_default():
            clusters_t = tf.Variable(
                tf.zeros((num_clusters, num_features)),  # Dummy init op
                name='clusters')
            with tf.Session() as sess:
                # The checkpoint from the last training step is expected to
                # be named after the step count.
                tf.train.Saver(var_list=[clusters_t]).restore(
                    sess,
                    os.path.join(output_dir, 'model.ckpt-%d' % train_steps))
                return clusters_t.eval()
    finally:
        # Checkpoints are only needed to extract the centers; always clean up.
        shutil.rmtree(output_dir)