def main()

in models/official/mnasnet/mnasnet_main.py [0:0]


def _build_run_config(params, checkpoint_interval):
  """Builds the `tf.estimator.tpu.RunConfig` used by the MnasNet estimator.

  Args:
    params: the validated, locked ParamsDict.
    checkpoint_interval: step interval for synchronous checkpointing. It is
      ignored (synchronous checkpointing is disabled) when
      `params.use_async_checkpointing` is set.

  Returns:
    A `tf.estimator.tpu.RunConfig`.
  """
  if FLAGS.tpu or params.use_tpu:
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
  else:
    tpu_cluster_resolver = None

  # With async checkpointing, 'train' mode installs an AsyncCheckpointSaverHook
  # instead, so the synchronous saver is turned off here.
  # NOTE(review): 'train_and_eval' mode installs no such hook — confirm that
  # checkpoints are still written when use_async_checkpointing is set there.
  if params.use_async_checkpointing:
    save_checkpoints_steps = None
  else:
    save_checkpoints_steps = checkpoint_interval

  # Enables automatic outside compilation. Required in order to
  # automatically detect summary ops to run on CPU instead of TPU.
  tf.config.set_soft_device_placement(True)

  return tf.estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=save_checkpoints_steps,
      log_step_count_steps=FLAGS.log_step_count_steps,
      session_config=tf.ConfigProto(
          graph_options=tf.GraphOptions(
              rewrite_options=rewriter_config_pb2.RewriterConfig(
                  disable_meta_optimizer=True))),
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=params.iterations_per_loop,
          per_host_input_for_training=tf.estimator.tpu.InputPipelineConfig
          .PER_HOST_V2))  # pylint: disable=line-too-long


def _build_input_pipelines(params):
  """Creates the training and evaluation ImageNet inputs.

  Input pipelines are slightly different (with regards to shuffling and
  preprocessing) between training and evaluation.

  Args:
    params: the validated, locked ParamsDict.

  Returns:
    A two-element list `[imagenet_train, imagenet_eval]`.
  """
  if FLAGS.bigtable_instance:
    tf.logging.info('Using Bigtable dataset, table %s', FLAGS.bigtable_table)
    select_train, select_eval = _select_tables_from_flags()
    # NOTE(review): Bigtable inputs always use use_bfloat16=False, whereas the
    # file-based inputs below follow params.precision — confirm intended.
    return [
        imagenet_input.ImageNetBigtableInput(
            is_training=is_training,
            use_bfloat16=False,
            transpose_input=params.transpose_input,
            selection=selection)
        for is_training, selection in ((True, select_train),
                                       (False, select_eval))
    ]

  if FLAGS.data_dir == FAKE_DATA_DIR:
    tf.logging.info('Using fake dataset.')
  else:
    tf.logging.info('Using dataset: %s', FLAGS.data_dir)
  use_bfloat16 = params.precision == 'bfloat16'
  return [
      imagenet_input.ImageNetInput(
          is_training=is_training,
          data_dir=FLAGS.data_dir,
          transpose_input=params.transpose_input,
          # Only the training pipeline is ever cached.
          cache=params.use_cache and is_training,
          image_size=params.input_image_size,
          num_parallel_calls=params.num_parallel_calls,
          use_bfloat16=use_bfloat16)
      for is_training in (True, False)
  ]


def _run_eval_loop(mnasnet_est, imagenet_eval, params, eval_steps):
  """Evaluates every new checkpoint in `FLAGS.model_dir`.

  Stops after evaluating a checkpoint whose step is at or past
  `params.train_steps`, or when `tf.train.checkpoints_iterator` stops yielding
  after `FLAGS.eval_timeout`.

  Args:
    mnasnet_est: the TPUEstimator to evaluate with.
    imagenet_eval: evaluation input object exposing `input_fn`.
    params: the validated, locked ParamsDict.
    eval_steps: number of eval batches per evaluation.
  """
  for ckpt in tf.train.checkpoints_iterator(
      FLAGS.model_dir, timeout=FLAGS.eval_timeout):
    tf.logging.info('Starting to evaluate.')
    try:
      start_timestamp = time.time()  # This time will include compilation time
      eval_results = mnasnet_est.evaluate(
          input_fn=imagenet_eval.input_fn,
          steps=eval_steps,
          checkpoint_path=ckpt)
      elapsed_time = int(time.time() - start_timestamp)
      tf.logging.info('Eval results: %s. Elapsed seconds: %d', eval_results,
                      elapsed_time)
      mnas_utils.archive_ckpt(eval_results, eval_results['top_1_accuracy'],
                              ckpt)

      # Terminate eval job when final checkpoint is reached. Assumes the
      # checkpoint basename ends in "-<global_step>" (e.g. model.ckpt-1000);
      # rsplit keeps this working even if the prefix itself contains '-'.
      current_step = int(os.path.basename(ckpt).rsplit('-', 1)[-1])
      if current_step >= params.train_steps:
        tf.logging.info('Evaluation finished after training step %d',
                        current_step)
        break

    except tf.errors.NotFoundError:
      # Since the coordinator is on a different job than the TPU worker,
      # sometimes the TPU worker does not finish initializing until long after
      # the CPU job tells it to start evaluating. In this case, the checkpoint
      # file could have been deleted already.
      tf.logging.info('Checkpoint %s no longer exists, skipping checkpoint',
                      ckpt)


def _run_train(mnasnet_est, imagenet_train, params, checkpoint_interval):
  """Trains up to `params.train_steps`, optionally with async checkpointing.

  Args:
    mnasnet_est: the TPUEstimator to train.
    imagenet_train: training input object exposing `input_fn`.
    params: the validated, locked ParamsDict.
    checkpoint_interval: save interval (in steps) for the async saver hook.

  Raises:
    ImportError: if async checkpointing is requested but the TF 1.x contrib
      module providing it is unavailable.
  """
  hooks = []
  if params.use_async_checkpointing:
    try:
      from tensorflow.contrib.tpu.python.tpu import async_checkpoint  # pylint: disable=g-import-not-at-top
    except ImportError:
      logging.exception(
          'Async checkpointing is not supported in TensorFlow 2.x')
      raise

    hooks.append(
        async_checkpoint.AsyncCheckpointSaverHook(
            checkpoint_dir=FLAGS.model_dir,
            save_steps=checkpoint_interval))
  mnasnet_est.train(
      input_fn=imagenet_train.input_fn,
      max_steps=params.train_steps,
      hooks=hooks)


def _run_train_and_eval(mnasnet_est, imagenet_train, imagenet_eval, params,
                        current_step, eval_steps):
  """Alternates training and evaluation until `params.train_steps`.

  Each round trains for up to `FLAGS.steps_per_eval` steps, evaluates the
  resulting model, and archives the latest checkpoint through `mnas_utils`.
  The caller must ensure `FLAGS.steps_per_eval` is positive; otherwise the
  loop never advances.

  Args:
    mnasnet_est: the TPUEstimator to train and evaluate.
    imagenet_train: training input object exposing `input_fn`.
    imagenet_eval: evaluation input object exposing `input_fn`.
    params: the validated, locked ParamsDict.
    current_step: global step to resume training from.
    eval_steps: number of eval batches per evaluation.
  """
  start_timestamp = time.time()  # This time will include compilation time
  while current_step < params.train_steps:
    # Train for up to steps_per_eval number of steps.
    # At the end of training, a checkpoint will be written to --model_dir.
    next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                          params.train_steps)
    mnasnet_est.train(
        input_fn=imagenet_train.input_fn, max_steps=next_checkpoint)
    current_step = next_checkpoint

    tf.logging.info('Finished training up to step %d. Elapsed seconds %d.',
                    next_checkpoint, int(time.time() - start_timestamp))

    # Evaluate the model on the most recent model in --model_dir.
    # Since evaluation happens in batches of --eval_batch_size, some images
    # may be excluded modulo the batch size. As long as the batch size is
    # consistent, the evaluated images are also consistent.
    tf.logging.info('Starting to evaluate.')
    eval_results = mnasnet_est.evaluate(
        input_fn=imagenet_eval.input_fn, steps=eval_steps)
    tf.logging.info('Eval results at step %d: %s', next_checkpoint,
                    eval_results)
    ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
    mnas_utils.archive_ckpt(eval_results, eval_results['top_1_accuracy'], ckpt)

  elapsed_time = int(time.time() - start_timestamp)
  tf.logging.info('Finished training up to step %d. Elapsed seconds %d.',
                  params.train_steps, elapsed_time)


def main(unused_argv):
  """Trains, evaluates and/or exports an MnasNet model, driven by flags.

  Parameters start from `mnasnet_config.MNASNET_CFG` and are overridden, in
  order, by `--config_file`, `--params_override`, the remaining input flags,
  and a few derived values. `--mode` then selects the behavior:

    * 'export_only': export the model to `--export_dir` and return.
    * 'eval': evaluate each new checkpoint in `--model_dir` until one at or
      past `train_steps` has been evaluated (or the checkpoint iterator times
      out), then export if `--export_dir` is set.
    * 'train': train up to `train_steps`.
    * 'train_and_eval': alternate `--steps_per_eval` training steps with a
      full evaluation until `train_steps`, then export if `--export_dir` is
      set.

  Args:
    unused_argv: positional command-line arguments (ignored).

  Raises:
    ValueError: if `--mode` is not one of the modes above, if
      `--steps_per_eval` is not positive in 'train_and_eval' mode, or if
      bfloat16 precision is combined with Keras layers.
  """
  params = params_dict.ParamsDict(
      mnasnet_config.MNASNET_CFG, mnasnet_config.MNASNET_RESTRICTIONS)
  params = params_dict.override_params_dict(
      params, FLAGS.config_file, is_strict=True)
  params = params_dict.override_params_dict(
      params, FLAGS.params_override, is_strict=True)

  params = flags_to_params.override_params_from_input_flags(params, FLAGS)

  additional_params = {
      # True division on purpose: this may be fractional.
      'steps_per_epoch': params.num_train_images / params.train_batch_size,
      'quantized_training': FLAGS.quantized_training,
      'add_summaries': FLAGS.add_summaries,
  }

  params = params_dict.override_params_dict(
      params, additional_params, is_strict=False)

  params.validate()
  params.lock()

  # Validate flags before any TPU/estimator setup so that misconfigurations
  # fail fast. An explicit raise is used rather than `assert`, because asserts
  # are stripped under `python -O`.
  valid_modes = ('train', 'eval', 'train_and_eval', 'export_only')
  if FLAGS.mode not in valid_modes:
    raise ValueError('Unsupported mode %r; expected one of %s.' %
                     (FLAGS.mode, ', '.join(valid_modes)))
  if FLAGS.mode == 'train_and_eval' and (FLAGS.steps_per_eval or 0) <= 0:
    # A non-positive step count never advances the train/eval loop.
    raise ValueError('steps_per_eval must be positive in train_and_eval '
                     'mode, got %r.' % (FLAGS.steps_per_eval,))
  if params.precision == 'bfloat16' and params.use_keras:
    raise ValueError(
        'Keras layers do not have full support to bfloat16 activation training.'
        ' You have set precision as %s and use_keras as %s' %
        (params.precision, params.use_keras))

  # Shared by the synchronous saver (RunConfig) and the async saver hook.
  checkpoint_interval = max(100, params.iterations_per_loop)
  config = _build_run_config(params, checkpoint_interval)

  # Initializes model parameters.
  mnasnet_est = tf.estimator.tpu.TPUEstimator(
      use_tpu=params.use_tpu,
      model_fn=build_model_fn,
      config=config,
      train_batch_size=params.train_batch_size,
      eval_batch_size=params.eval_batch_size,
      export_to_tpu=FLAGS.export_to_tpu,
      params=params.as_dict())

  if FLAGS.mode == 'export_only':
    export(mnasnet_est, FLAGS.export_dir, params, FLAGS.post_quantize)
    return

  imagenet_train, imagenet_eval = _build_input_pipelines(params)
  eval_steps = params.num_eval_images // params.eval_batch_size

  if FLAGS.mode == 'eval':
    _run_eval_loop(mnasnet_est, imagenet_eval, params, eval_steps)
    if FLAGS.export_dir:
      export(mnasnet_est, FLAGS.export_dir, params, FLAGS.post_quantize)
    return

  # Remaining modes: 'train' or 'train_and_eval'. Resume from the global step
  # stored in --model_dir, or start from 0 if none can be loaded.
  try:
    current_step = tf.train.load_variable(FLAGS.model_dir,
                                          tf.GraphKeys.GLOBAL_STEP)
  except (TypeError, ValueError, tf.errors.NotFoundError):
    current_step = 0

  tf.logging.info(
      'Training for %d steps (%.2f epochs in total). Current'
      ' step %d.', params.train_steps,
      params.train_steps / params.steps_per_epoch, current_step)

  if FLAGS.mode == 'train':
    _run_train(mnasnet_est, imagenet_train, params, checkpoint_interval)
  else:
    _run_train_and_eval(mnasnet_est, imagenet_train, imagenet_eval, params,
                        current_step, eval_steps)
    if FLAGS.export_dir:
      export(mnasnet_est, FLAGS.export_dir, params, FLAGS.post_quantize)