`main()` — defined in models/official/mobilenet/mobilenet.py


def main(unused_argv):
  """Trains, evaluates, and/or exports a MobileNet model on TPU.

  The behavior is selected by FLAGS.mode:
    * 'train'          -- train for params.train_steps.
    * 'eval'           -- poll FLAGS.model_dir for new checkpoints and
                          evaluate each until FLAGS.eval_timeout expires.
    * 'train_and_eval' -- alternate training and evaluation cycles.

  Afterwards, optionally exports a SavedModel (FLAGS.export_dir) and/or a
  TFLite flatbuffer (FLAGS.tflite_export_dir).

  Args:
    unused_argv: Unused positional command-line arguments (from absl.app).
  """
  del unused_argv  # Unused

  # Assemble hyperparameters: library defaults, then config file, then
  # explicit command-line overrides (later overrides win).
  params = params_dict.ParamsDict({}, mobilenet_config.MOBILENET_RESTRICTIONS)
  params = flags_to_params.override_params_from_input_flags(params, FLAGS)
  params = params_dict.override_params_dict(
      params, mobilenet_config.MOBILENET_CFG, is_strict=False)
  params = params_dict.override_params_dict(
      params, FLAGS.config_file, is_strict=True)
  params = params_dict.override_params_dict(
      params, FLAGS.params_override, is_strict=True)

  # Default: no transposition (identity permutation over NHWC-like axes).
  input_perm = [0, 1, 2, 3]
  output_perm = [0, 1, 2, 3]

  batch_axis = 0
  batch_size_per_shard = params.train_batch_size // params.num_cores
  if params.transpose_enabled:
    # Move the batch dimension off axis 0 for host/device transfer; the
    # permutation chosen depends on the per-shard batch size (presumably a
    # TPU layout performance heuristic -- see the infeed docs).
    if batch_size_per_shard >= 64:
      input_perm = [3, 0, 1, 2]
      output_perm = [1, 2, 3, 0]
      batch_axis = 3
    else:
      input_perm = [2, 0, 1, 3]
      output_perm = [1, 2, 0, 3]
      batch_axis = 2

  additional_params = {
      'input_perm': input_perm,
      'output_perm': output_perm,
  }
  params = params_dict.override_params_dict(
      params, additional_params, is_strict=False)

  # Freeze params so later code cannot mutate them accidentally.
  params.validate()
  params.lock()

  tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu if (FLAGS.tpu or params.use_tpu) else '',
      zone=FLAGS.tpu_zone,
      project=FLAGS.gcp_project)

  # Number of images to evaluate: an explicit eval_total_size takes
  # precedence over the dataset-derived num_eval_images.
  if params.eval_total_size > 0:
    eval_size = params.eval_total_size
  else:
    eval_size = params.num_eval_images
  eval_steps = eval_size // params.eval_batch_size

  iterations = (eval_steps if FLAGS.mode == 'eval' else
                params.iterations_per_loop)

  # Pure training mode needs no eval batch size.
  eval_batch_size = (None if FLAGS.mode == 'train' else
                     params.eval_batch_size)

  per_host_input_for_training = (params.num_cores <= 8 if
                                 FLAGS.mode == 'train' else True)

  run_config = tf.estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      save_checkpoints_secs=FLAGS.save_checkpoints_secs,
      save_summary_steps=FLAGS.save_summary_steps,
      session_config=tf.ConfigProto(
          allow_soft_placement=True,
          log_device_placement=FLAGS.log_device_placement),
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=iterations,
          per_host_input_for_training=per_host_input_for_training))

  inception_classifier = tf.estimator.tpu.TPUEstimator(
      model_fn=model_fn,
      use_tpu=params.use_tpu,
      config=run_config,
      params=params.as_dict(),
      train_batch_size=params.train_batch_size,
      eval_batch_size=eval_batch_size,
      batch_axis=(batch_axis, 0))

  # Input pipelines are slightly different (with regards to shuffling and
  # preprocessing) between training and evaluation.
  imagenet_train = supervised_images.InputPipeline(
      is_training=True,
      data_dir=FLAGS.data_dir)
  imagenet_eval = supervised_images.InputPipeline(
      is_training=False,
      data_dir=FLAGS.data_dir)

  if params.moving_average:
    # Evaluate against the exponential-moving-average shadow variables.
    eval_hooks = [LoadEMAHook(FLAGS.model_dir)]
  else:
    eval_hooks = []

  if FLAGS.mode == 'eval':
    def terminate_eval():
      # timeout_fn for checkpoints_iterator: returning True ends polling.
      absl.logging.info('%d seconds without new checkpoints have elapsed '
                        '... terminating eval', FLAGS.eval_timeout)
      return True

    def get_next_checkpoint():
      return tf.train.checkpoints_iterator(
          FLAGS.model_dir,
          min_interval_secs=params.min_eval_interval,
          timeout=FLAGS.eval_timeout,
          timeout_fn=terminate_eval)

    for checkpoint in get_next_checkpoint():
      absl.logging.info('Starting to evaluate.')
      try:
        eval_results = inception_classifier.evaluate(
            input_fn=imagenet_eval.input_fn,
            steps=eval_steps,
            hooks=eval_hooks,
            checkpoint_path=checkpoint)
        absl.logging.info('Evaluation results: %s', eval_results)
      except tf.errors.NotFoundError:
        # Skip checkpoints deleted (e.g. garbage-collected) before we could
        # evaluate them.  Bug fix: the original never passed `checkpoint`,
        # so the %s placeholder was logged verbatim.
        absl.logging.info('Checkpoint %s no longer exists ... skipping',
                          checkpoint)

  elif FLAGS.mode == 'train_and_eval':
    for cycle in range(params.train_steps // params.train_steps_per_eval):
      absl.logging.info('Starting training cycle %d.', cycle)
      inception_classifier.train(
          input_fn=imagenet_train.input_fn,
          steps=params.train_steps_per_eval)

      absl.logging.info('Starting evaluation cycle %d .', cycle)
      eval_results = inception_classifier.evaluate(
          input_fn=imagenet_eval.input_fn, steps=eval_steps, hooks=eval_hooks)
      absl.logging.info('Evaluation results: %s', eval_results)

  else:
    absl.logging.info('Starting training ...')
    inception_classifier.train(
        input_fn=imagenet_train.input_fn, steps=params.train_steps)

  if FLAGS.export_dir:
    absl.logging.info('Starting to export model with image input.')
    inception_classifier.export_saved_model(
        export_dir_base=FLAGS.export_dir,
        serving_input_receiver_fn=image_serving_input_fn)

  if FLAGS.tflite_export_dir:
    absl.logging.info('Starting to export default TensorFlow model.')
    savedmodel_dir = inception_classifier.export_saved_model(
        export_dir_base=FLAGS.tflite_export_dir,
        serving_input_receiver_fn=functools.partial(
            tensor_serving_input_fn, params))
    absl.logging.info('Starting to export TFLite.')
    converter = tf.lite.TFLiteConverter.from_saved_model(
        savedmodel_dir,
        output_arrays=['softmax_tensor'])
    tflite_file_name = 'mobilenet.tflite'
    if params.post_quantize:
      converter.post_training_quantize = True
      tflite_file_name = 'quantized_' + tflite_file_name
    tflite_file = os.path.join(savedmodel_dir, tflite_file_name)
    tflite_model = converter.convert()
    # Bug fix: close the file handle deterministically instead of leaking
    # it via an unclosed GFile.
    with tf.gfile.GFile(tflite_file, 'wb') as f:
      f.write(tflite_model)