def main()

in models/official/efficientnet/main.py


def main(unused_argv):

  input_image_size = FLAGS.input_image_size
  if not input_image_size:
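    # No explicit size was given; fall back to the model's default input
    # resolution (e.g. 224 for efficientnet-b0).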
    input_image_size = model_builder_factory.get_model_input_size(
        FLAGS.model_name)

  if FLAGS.holdout_shards:
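    # The training data is assumed to be stored in 1024 TFRecord shards, so
    # each held-out shard diverts num_train_images / 1024 images to eval.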
    holdout_images = int(FLAGS.num_train_images * FLAGS.holdout_shards / 1024.0)
    FLAGS.num_train_images -= holdout_images
    if FLAGS.eval_name and 'test' in FLAGS.eval_name:
      FLAGS.holdout_shards = None  # do not use holdout if eval test set.
    else:
      FLAGS.num_eval_images = holdout_images

  # For the ImageNet dataset, include the background label if the number of
  # output classes is 1001.
  include_background_label = (FLAGS.num_label_classes == 1001)

  if FLAGS.tpu or FLAGS.use_tpu:
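    # Resolve the TPU worker from its name or grpc:// address; no resolver
    # is needed when running on CPU or GPU.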
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu,
        zone=FLAGS.tpu_zone,
        project=FLAGS.gcp_project)
  else:
    tpu_cluster_resolver = None

  if FLAGS.use_async_checkpointing:
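    # The AsyncCheckpointSaverHook added in train() handles saving, so the
    # estimator's built-in periodic saver is disabled here.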
    save_checkpoints_steps = None
  else:
    save_checkpoints_steps = max(100, FLAGS.iterations_per_loop)
  config = tf.estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=save_checkpoints_steps,
      log_step_count_steps=FLAGS.log_step_count_steps,
      session_config=tf.ConfigProto(
          graph_options=tf.GraphOptions(
              rewrite_options=rewriter_config_pb2.RewriterConfig(
                  disable_meta_optimizer=True))),
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          tpu_job_name=FLAGS.tpu_job_name,
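          # PER_HOST_V2 builds one input pipeline per host; each host's
          # dataset feeds all of that host's TPU cores.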
          per_host_input_for_training=tf.estimator.tpu.InputPipelineConfig
          .PER_HOST_V2))  # pylint: disable=line-too-long
  # Initializes model parameters.
  params = dict(
      steps_per_epoch=FLAGS.num_train_images / FLAGS.train_batch_size,
      use_bfloat16=FLAGS.use_bfloat16)
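  # TPUEstimator merges these params into the dict passed to model_fn, which
  # can use steps_per_epoch, e.g. for an epoch-based learning-rate schedule.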
  est = tf.estimator.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      export_to_tpu=FLAGS.export_to_tpu,
      params=params)

  if (FLAGS.model_name.startswith('efficientnet-lite') or
      FLAGS.model_name.startswith('efficientnet-edgetpu')):
    # The lite and edgetpu variants use bilinear resizing, which is
    # friendlier to post-training quantization.
    resize_method = tf.image.ResizeMethod.BILINEAR
  else:
    resize_method = None
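    # None defers to the preprocessing pipeline's default resize method.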
  # Input pipelines differ slightly between training and evaluation, with
  # regard to shuffling, preprocessing, and caching.
  def build_imagenet_input(is_training):
    """Generate ImageNetInput for training and eval."""
    if FLAGS.bigtable_instance:
      logging.info('Using Bigtable dataset, table %s', FLAGS.bigtable_table)
      select_train, select_eval = _select_tables_from_flags()
      return imagenet_input.ImageNetBigtableInput(
          is_training=is_training,
          use_bfloat16=FLAGS.use_bfloat16,
          transpose_input=FLAGS.transpose_input,
          selection=select_train if is_training else select_eval,
          num_label_classes=FLAGS.num_label_classes,
          include_background_label=include_background_label,
          augment_name=FLAGS.augment_name,
          mixup_alpha=FLAGS.mixup_alpha,
          randaug_num_layers=FLAGS.randaug_num_layers,
          randaug_magnitude=FLAGS.randaug_magnitude,
          resize_method=resize_method)
    else:
      if FLAGS.data_dir == FAKE_DATA_DIR:
        logging.info('Using fake dataset.')
      else:
        logging.info('Using dataset: %s', FLAGS.data_dir)

      return imagenet_input.ImageNetInput(
          is_training=is_training,
          data_dir=FLAGS.data_dir,
          transpose_input=FLAGS.transpose_input,
          cache=FLAGS.use_cache and is_training,
          image_size=input_image_size,
          num_parallel_calls=FLAGS.num_parallel_calls,
          use_bfloat16=FLAGS.use_bfloat16,
          num_label_classes=FLAGS.num_label_classes,
          include_background_label=include_background_label,
          augment_name=FLAGS.augment_name,
          mixup_alpha=FLAGS.mixup_alpha,
          randaug_num_layers=FLAGS.randaug_num_layers,
          randaug_magnitude=FLAGS.randaug_magnitude,
          resize_method=resize_method,
          holdout_shards=FLAGS.holdout_shards)

  imagenet_train = build_imagenet_input(is_training=True)
  imagenet_eval = build_imagenet_input(is_training=False)

  if FLAGS.mode == 'eval':
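    # Integer division drops any final partial batch of eval images.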
    eval_steps = FLAGS.num_eval_images // FLAGS.eval_batch_size
    # Run evaluation whenever a new checkpoint appears; checkpoints_iterator
    # blocks until one arrives or --eval_timeout seconds elapse.
    for ckpt in tf.train.checkpoints_iterator(
        FLAGS.model_dir, timeout=FLAGS.eval_timeout):
      logging.info('Starting to evaluate.')
      try:
        start_timestamp = time.time()  # This time will include compilation time
        eval_results = est.evaluate(
            input_fn=imagenet_eval.input_fn,
            steps=eval_steps,
            checkpoint_path=ckpt,
            name=FLAGS.eval_name)
        elapsed_time = int(time.time() - start_timestamp)
        logging.info('Eval results: %s. Elapsed seconds: %d',
                     eval_results, elapsed_time)
        if FLAGS.archive_ckpt:
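          # archive_ckpt keeps a copy of the checkpoint when it improves on
          # the best top-1 accuracy seen so far.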
          utils.archive_ckpt(eval_results, eval_results['top_1_accuracy'], ckpt)

        # Terminate eval job when final checkpoint is reached
        try:
          current_step = int(os.path.basename(ckpt).split('-')[1])
        except IndexError:
          logging.info('%s has no global step info: stop!', ckpt)
          break

        if current_step >= FLAGS.train_steps:
          logging.info(
              'Evaluation finished after training step %d', current_step)
          break

      except tf.errors.NotFoundError:
        # Since the coordinator is on a different job than the TPU worker,
        # sometimes the TPU worker does not finish initializing until long after
        # the CPU job tells it to start evaluating. In this case, the checkpoint
        # file could have been deleted already.
        logging.info(
            'Checkpoint %s no longer exists, skipping checkpoint', ckpt)
  else:   # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
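    # Read the current global step from the latest checkpoint in model_dir
    # (0 if no checkpoint exists yet).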
    current_step = estimator._load_global_step_from_checkpoint_dir(FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long

    logging.info(
        'Training for %d steps (%.2f epochs in total). Current'
        ' step %d.', FLAGS.train_steps,
        FLAGS.train_steps / params['steps_per_epoch'], current_step)

    start_timestamp = time.time()  # This time will include compilation time

    if FLAGS.mode == 'train':
      hooks = []
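      # With async checkpointing, a hook saves checkpoints from a background
      # thread (the built-in periodic saver was disabled above).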
      if FLAGS.use_async_checkpointing:
        try:
          from tensorflow.contrib.tpu.python.tpu import async_checkpoint  # pylint: disable=g-import-not-at-top
        except ImportError as e:
          logging.exception(
              'Async checkpointing is not supported in TensorFlow 2.x')
          raise e

        hooks.append(
            async_checkpoint.AsyncCheckpointSaverHook(
                checkpoint_dir=FLAGS.model_dir,
                save_steps=max(100, FLAGS.iterations_per_loop)))
      est.train(
          input_fn=imagenet_train.input_fn,
          max_steps=FLAGS.train_steps,
          hooks=hooks)

    else:
      assert FLAGS.mode == 'train_and_eval'
      while current_step < FLAGS.train_steps:
        # Train for up to steps_per_eval steps at a time.
        # At the end of training, a checkpoint will be written to --model_dir.
        next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                              FLAGS.train_steps)
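        # max_steps is an absolute step count: each call resumes from the
        # latest checkpoint and stops once next_checkpoint is reached.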
        est.train(input_fn=imagenet_train.input_fn, max_steps=next_checkpoint)
        current_step = next_checkpoint

        logging.info('Finished training up to step %d. Elapsed seconds %d.',
                     next_checkpoint, int(time.time() - start_timestamp))

        # Evaluate the model on the most recent model in --model_dir.
        # Since evaluation happens in batches of --eval_batch_size, any final
        # partial batch of images is dropped. As long as the batch size stays
        # the same, the same set of images is evaluated each run.
        logging.info('Starting to evaluate.')
        eval_results = est.evaluate(
            input_fn=imagenet_eval.input_fn,
            steps=FLAGS.num_eval_images // FLAGS.eval_batch_size,
            name=FLAGS.eval_name)
        logging.info('Eval results at step %d: %s',
                     next_checkpoint, eval_results)
        ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
        if FLAGS.archive_ckpt:
          utils.archive_ckpt(eval_results, eval_results['top_1_accuracy'], ckpt)

      elapsed_time = int(time.time() - start_timestamp)
      logging.info('Finished training up to step %d. Elapsed seconds %d.',
                   FLAGS.train_steps, elapsed_time)
  if FLAGS.export_dir:
    export(est, FLAGS.export_dir, input_image_size)
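
For context, here is a minimal sketch of the entry point that would invoke
main(). The absl.app pattern and the TF1 compatibility import are assumptions
inferred from the TF1-style APIs used above (tf.ConfigProto,
tf.estimator.tpu), not lines from this file; the flags in the example
invocation are placeholders.

# Hypothetical entry point (not part of the excerpt above).
from absl import app
from absl import logging
import tensorflow.compat.v1 as tf

if __name__ == '__main__':
  logging.set_verbosity(logging.INFO)
  tf.disable_v2_behavior()  # main() relies on TF1 Estimator/TPUEstimator APIs.
  app.run(main)

# Example invocation (bucket paths and TPU name are placeholders):
#   python main.py --model_name=efficientnet-b0 --mode=train_and_eval \
#     --data_dir=gs://bucket/imagenet --model_dir=gs://bucket/model \
#     --tpu=my-tpu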