in example_zoo/tensorflow/models/mnist/official/mnist/mnist.py [0:0]
def run_mnist(flags_obj):
    """Run the MNIST training and evaluation loop, then optionally export.

    Args:
        flags_obj: An object containing parsed flag values. The attributes
            read here are `inter_op_parallelism_threads`,
            `intra_op_parallelism_threads`, `all_reduce_alg`, `data_format`,
            `model_dir`, `data_dir`, `batch_size`, `epochs_between_evals`,
            `hooks`, `train_epochs`, `stop_threshold` and `export_dir`.
    """
    model_helpers.apply_clean(flags_obj)

    # Session options: thread-pool sizes come from the flags, and soft
    # placement is always enabled.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    num_gpus = flags_core.get_num_gpus(flags_obj)
    strategy = distribution_utils.get_distribution_strategy(
        num_gpus, flags_obj.all_reduce_alg)

    run_config = tf.estimator.RunConfig(
        train_distribute=strategy, session_config=session_config)

    # When no data format is given, pick one based on whether this
    # TensorFlow build has CUDA support.
    data_format = flags_obj.data_format
    if data_format is None:
        if tf.test.is_built_with_cuda():
            data_format = 'channels_first'
        else:
            data_format = 'channels_last'

    estimator_params = {'data_format': data_format}
    mnist_classifier = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=flags_obj.model_dir,
        config=run_config,
        params=estimator_params)

    def train_input_fn():
        """Prepare data for training."""
        train_ds = dataset.train(flags_obj.data_dir)
        # Larger shuffle buffers give better randomness, smaller ones use
        # less memory; a 50000-element buffer is used here.
        train_ds = train_ds.cache()
        train_ds = train_ds.shuffle(buffer_size=50000)
        train_ds = train_ds.batch(flags_obj.batch_size)
        # Each call to `train` iterates the data `epochs_between_evals`
        # times before evaluation runs.
        return train_ds.repeat(flags_obj.epochs_between_evals)

    def eval_input_fn():
        """Return the next-element op of a one-shot pass over the test set."""
        test_ds = dataset.test(flags_obj.data_dir)
        test_ds = test_ds.batch(flags_obj.batch_size)
        iterator = test_ds.make_one_shot_iterator()
        return iterator.get_next()

    # Build the training hooks requested through `flags_obj.hooks`.
    train_hooks = hooks_helper.get_train_hooks(
        flags_obj.hooks,
        model_dir=flags_obj.model_dir,
        batch_size=flags_obj.batch_size)

    # Alternate between training and evaluation. The count uses integer
    # division, so a remainder of `train_epochs` is not trained.
    num_cycles = flags_obj.train_epochs // flags_obj.epochs_between_evals
    for _ in range(num_cycles):
        mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
        eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
        print('\nEvaluation results:\n\t%s\n' % eval_results)

        reached_threshold = model_helpers.past_stop_threshold(
            flags_obj.stop_threshold, eval_results['accuracy'])
        if reached_threshold:
            break

    # Nothing left to do unless an export directory was requested.
    if flags_obj.export_dir is None:
        return

    # Export a SavedModel that accepts a raw float32 `image` tensor of shape
    # [None, 28, 28].
    image = tf.placeholder(tf.float32, [None, 28, 28])
    serving_features = {'image': image}
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
        serving_features)
    mnist_classifier.export_savedmodel(
        flags_obj.export_dir, input_fn, strip_default_attrs=True)