in models/official/mobilenet/mobilenet.py [0:0]
def main(unused_argv):
  """Configures a MobileNet TPUEstimator and runs the mode in FLAGS.mode.

  Builds the hyperparameter `params` by starting from
  MOBILENET_RESTRICTIONS and applying, in this order: command-line flags,
  MOBILENET_CFG, FLAGS.config_file and FLAGS.params_override. It then derives
  the input/output transpose permutations and validates and locks `params`.
  Depending on FLAGS.mode it then:

    * 'eval': evaluates each new checkpoint in FLAGS.model_dir until
      FLAGS.eval_timeout seconds pass without a new one;
    * 'train_and_eval': alternates train_steps_per_eval training steps with a
      full evaluation;
    * anything else: trains for params.train_steps steps.

  Afterwards it optionally exports a SavedModel with image input
  (FLAGS.export_dir). It also optionally exports a SavedModel plus a TFLite
  flatbuffer (FLAGS.tflite_export_dir); the flatbuffer is quantized when
  params.post_quantize is set.

  Args:
    unused_argv: Positional command-line arguments; ignored.
  """
  del unused_argv  # Unused

  # NOTE(review): command-line flags are applied to an empty ParamsDict
  # *before* MOBILENET_CFG is merged in. If override_params_dict replaces
  # existing values, the config defaults win over flag values -- confirm this
  # precedence is intended.
  params = params_dict.ParamsDict({}, mobilenet_config.MOBILENET_RESTRICTIONS)
  params = flags_to_params.override_params_from_input_flags(params, FLAGS)
  params = params_dict.override_params_dict(
      params, mobilenet_config.MOBILENET_CFG, is_strict=False)
  params = params_dict.override_params_dict(
      params, FLAGS.config_file, is_strict=True)
  params = params_dict.override_params_dict(
      params, FLAGS.params_override, is_strict=True)

  # Default: identity permutations and features batched along axis 0. With
  # transposition enabled, the features' batch axis (passed to TPUEstimator
  # below for sharding) becomes 3 for per-shard batches of at least 64 and
  # 2 otherwise. input_perm/output_perm are forwarded through params;
  # presumably they are consumed by the input pipeline / model_fn -- confirm.
  input_perm = [0, 1, 2, 3]
  output_perm = [0, 1, 2, 3]
  batch_axis = 0
  batch_size_per_shard = params.train_batch_size // params.num_cores
  if params.transpose_enabled:
    if batch_size_per_shard >= 64:
      input_perm = [3, 0, 1, 2]
      output_perm = [1, 2, 3, 0]
      batch_axis = 3
    else:
      input_perm = [2, 0, 1, 3]
      output_perm = [1, 2, 0, 3]
      batch_axis = 2

  additional_params = {
      'input_perm': input_perm,
      'output_perm': output_perm,
  }
  params = params_dict.override_params_dict(
      params, additional_params, is_strict=False)
  params.validate()
  params.lock()

  tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu if (FLAGS.tpu or params.use_tpu) else '',
      zone=FLAGS.tpu_zone,
      project=FLAGS.gcp_project)

  # A positive eval_total_size takes priority over num_eval_images.
  if params.eval_total_size > 0:
    eval_size = params.eval_total_size
  else:
    eval_size = params.num_eval_images
  eval_steps = eval_size // params.eval_batch_size

  # In pure eval mode a single TPU loop runs all eval_steps at once.
  iterations = (eval_steps if FLAGS.mode == 'eval' else
                params.iterations_per_loop)
  eval_batch_size = (None if FLAGS.mode == 'train' else
                     params.eval_batch_size)
  per_host_input_for_training = (params.num_cores <= 8 if
                                 FLAGS.mode == 'train' else True)

  run_config = tf.estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      save_checkpoints_secs=FLAGS.save_checkpoints_secs,
      save_summary_steps=FLAGS.save_summary_steps,
      session_config=tf.ConfigProto(
          allow_soft_placement=True,
          log_device_placement=FLAGS.log_device_placement),
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=iterations,
          per_host_input_for_training=per_host_input_for_training))

  inception_classifier = tf.estimator.tpu.TPUEstimator(
      model_fn=model_fn,
      use_tpu=params.use_tpu,
      config=run_config,
      params=params.as_dict(),
      train_batch_size=params.train_batch_size,
      eval_batch_size=eval_batch_size,
      batch_axis=(batch_axis, 0))

  # Input pipelines are slightly different (with regards to shuffling and
  # preprocessing) between training and evaluation.
  imagenet_train = supervised_images.InputPipeline(
      is_training=True,
      data_dir=FLAGS.data_dir)
  imagenet_eval = supervised_images.InputPipeline(
      is_training=False,
      data_dir=FLAGS.data_dir)

  if params.moving_average:
    eval_hooks = [LoadEMAHook(FLAGS.model_dir)]
  else:
    eval_hooks = []

  if FLAGS.mode == 'eval':
    def terminate_eval():
      """Logs the timeout and returns True so the checkpoint iterator stops."""
      absl.logging.info('%d seconds without new checkpoints have elapsed '
                        '... terminating eval', FLAGS.eval_timeout)
      return True

    def get_next_checkpoint():
      """Returns an iterator over new checkpoint paths in FLAGS.model_dir."""
      return tf.train.checkpoints_iterator(
          FLAGS.model_dir,
          min_interval_secs=params.min_eval_interval,
          timeout=FLAGS.eval_timeout,
          timeout_fn=terminate_eval)

    for checkpoint in get_next_checkpoint():
      absl.logging.info('Starting to evaluate.')
      try:
        eval_results = inception_classifier.evaluate(
            input_fn=imagenet_eval.input_fn,
            steps=eval_steps,
            hooks=eval_hooks,
            checkpoint_path=checkpoint)
        absl.logging.info('Evaluation results: %s', eval_results)
      except tf.errors.NotFoundError:
        # Skip the checkpoint if it gets deleted prior to evaluation.
        absl.logging.info('Checkpoint %s no longer exists ... skipping',
                          checkpoint)
  elif FLAGS.mode == 'train_and_eval':
    # NOTE(review): floor division means any remainder of train_steps that
    # is not a multiple of train_steps_per_eval is never trained.
    for cycle in range(params.train_steps // params.train_steps_per_eval):
      absl.logging.info('Starting training cycle %d.', cycle)
      inception_classifier.train(
          input_fn=imagenet_train.input_fn,
          steps=params.train_steps_per_eval)

      absl.logging.info('Starting evaluation cycle %d .', cycle)
      eval_results = inception_classifier.evaluate(
          input_fn=imagenet_eval.input_fn, steps=eval_steps, hooks=eval_hooks)
      absl.logging.info('Evaluation results: %s', eval_results)
  else:
    absl.logging.info('Starting training ...')
    inception_classifier.train(
        input_fn=imagenet_train.input_fn, steps=params.train_steps)

  if FLAGS.export_dir:
    absl.logging.info('Starting to export model with image input.')
    inception_classifier.export_saved_model(
        export_dir_base=FLAGS.export_dir,
        serving_input_receiver_fn=image_serving_input_fn)

  if FLAGS.tflite_export_dir:
    absl.logging.info('Starting to export default TensorFlow model.')
    savedmodel_dir = inception_classifier.export_saved_model(
        export_dir_base=FLAGS.tflite_export_dir,
        serving_input_receiver_fn=functools.partial(
            tensor_serving_input_fn, params))
    # Estimator.export_saved_model can return the export path as bytes;
    # normalize to str so os.path.join below does not mix bytes and str
    # (a TypeError on Python 3).
    if isinstance(savedmodel_dir, bytes):
      savedmodel_dir = savedmodel_dir.decode('utf-8')

    absl.logging.info('Starting to export TFLite.')
    converter = tf.lite.TFLiteConverter.from_saved_model(
        savedmodel_dir,
        output_arrays=['softmax_tensor'])
    tflite_file_name = 'mobilenet.tflite'
    if params.post_quantize:
      converter.post_training_quantize = True
      tflite_file_name = 'quantized_' + tflite_file_name
    tflite_file = os.path.join(savedmodel_dir, tflite_file_name)
    tflite_model = converter.convert()
    # Use a context manager so the handle is flushed and closed.
    with tf.gfile.GFile(tflite_file, 'wb') as tflite_out:
      tflite_out.write(tflite_model)