def main()

in models/official/retinanet/retinanet_segmentation_main.py [0:0]
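main() validates the mode-specific flags, folds the --hparams overrides and command-line flags into a single params dict, builds the TPUEstimator run configuration, and dispatches to the train, eval, or train_and_eval path. The excerpt is only the function body; the module-level names it uses (FLAGS, contrib_cluster_resolver, contrib_tpu, contrib_training, dataloader, retinanet_segmentation_model) are defined elsewhere in the file. A minimal sketch of the imports the body assumes, in the TF 1.x contrib-alias style the code follows (an assumption for context, not a copy of the file's actual import block):

import os

from absl import flags  # assumption: FLAGS is defined via absl flags

import tensorflow as tf  # TF 1.x API: tf.logging, tf.ConfigProto, tf.Session

from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
from tensorflow.contrib import tpu as contrib_tpu
from tensorflow.contrib import training as contrib_training

import dataloader  # local module providing SegmentationInputReader
import retinanet_segmentation_model  # local module: model_fn and hparams

FLAGS = flags.FLAGS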


def main(argv):
  del argv  # Unused.

  if FLAGS.use_tpu:
    tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    tpu_grpc_url = tpu_cluster_resolver.get_master()
    tf.Session.reset(tpu_grpc_url)
  else:
    # Avoid a NameError below when --use_tpu is false: RunConfig still
    # receives cluster=tpu_cluster_resolver.
    tpu_cluster_resolver = None

  if FLAGS.mode in ('train',
                    'train_and_eval') and FLAGS.training_file_pattern is None:
    raise RuntimeError('You must specify --training_file_pattern for training.')
  if FLAGS.mode in ('eval', 'train_and_eval'):
    if FLAGS.validation_file_pattern is None:
      raise RuntimeError('You must specify '
                         '--validation_file_pattern for evaluation.')

  # Parse hparams
  hparams = retinanet_segmentation_model.default_hparams()
  hparams.parse(FLAGS.hparams)

  params = dict(
      hparams.values(),
      num_shards=FLAGS.num_shards,
      num_examples_per_epoch=FLAGS.num_examples_per_epoch,
      use_tpu=FLAGS.use_tpu,
      resnet_checkpoint=FLAGS.resnet_checkpoint,
      mode=FLAGS.mode,
  )

  # Estimator run configuration: log step counts once per TPU loop, keep at
  # most three checkpoints, and shard the TPU input pipeline per host
  # (PER_HOST_V2).
  run_config = contrib_tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      evaluation_master='',
      model_dir=FLAGS.model_dir,
      keep_checkpoint_max=3,
      log_step_count_steps=FLAGS.iterations_per_loop,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=False),
      tpu_config=contrib_tpu.TPUConfig(
          FLAGS.iterations_per_loop,
          FLAGS.num_shards,
          per_host_input_for_training=(
              contrib_tpu.InputPipelineConfig.PER_HOST_V2)))

  model_fn = retinanet_segmentation_model.segmentation_model_fn

  # Evaluation parameter overrides: disable random horizontal flips and
  # batch-norm training, and skip the ResNet checkpoint restore.
  eval_params = dict(
      params,
      use_tpu=FLAGS.use_tpu,
      input_rand_hflip=False,
      resnet_checkpoint=None,
      is_training_bn=False,
  )
  if FLAGS.mode == 'train':
    train_estimator = contrib_tpu.TPUEstimator(
        model_fn=model_fn,
        use_tpu=FLAGS.use_tpu,
        train_batch_size=FLAGS.train_batch_size,
        config=run_config,
        params=params)

    train_estimator.train(
        input_fn=dataloader.SegmentationInputReader(
            FLAGS.training_file_pattern, is_training=True),
        max_steps=int((FLAGS.num_epochs * FLAGS.num_examples_per_epoch) /
                      FLAGS.train_batch_size))

    if FLAGS.eval_after_training:
      # Run evaluation after training finishes.
      eval_estimator = contrib_tpu.TPUEstimator(
          model_fn=retinanet_segmentation_model.segmentation_model_fn,
          use_tpu=FLAGS.use_tpu,
          train_batch_size=FLAGS.train_batch_size,
          eval_batch_size=FLAGS.eval_batch_size,
          config=run_config,
          params=eval_params)
      eval_results = eval_estimator.evaluate(
          input_fn=dataloader.SegmentationInputReader(
              FLAGS.validation_file_pattern, is_training=False),
          steps=FLAGS.eval_samples // FLAGS.eval_batch_size)
      tf.logging.info('Eval results: %s' % eval_results)

  elif FLAGS.mode == 'eval':

    eval_estimator = contrib_tpu.TPUEstimator(
        model_fn=retinanet_segmentation_model.segmentation_model_fn,
        use_tpu=FLAGS.use_tpu,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        config=run_config,
        params=eval_params)

    # Continuous-evaluation helper: checkpoints_iterator below calls this
    # after FLAGS.eval_timeout seconds without a new checkpoint; returning
    # True ends the wait.
    def terminate_eval():
      tf.logging.info('Terminating eval after %d seconds of no checkpoints' %
                      FLAGS.eval_timeout)
      return True

    # Run evaluation when there's a new checkpoint
    for ckpt in contrib_training.checkpoints_iterator(
        FLAGS.model_dir,
        min_interval_secs=FLAGS.min_eval_interval,
        timeout=FLAGS.eval_timeout,
        timeout_fn=terminate_eval):

      tf.logging.info('Starting to evaluate.')
      try:
        # Note that if eval_samples is not evenly divisible by
        # eval_batch_size, the remainder is dropped, so the metrics can
        # differ slightly from evaluating on the full validation set.
        eval_results = eval_estimator.evaluate(
            input_fn=dataloader.SegmentationInputReader(
                FLAGS.validation_file_pattern, is_training=False),
            steps=FLAGS.eval_samples // FLAGS.eval_batch_size)
        tf.logging.info('Eval results: %s' % eval_results)

        # Terminate eval job when final checkpoint is reached
        current_step = int(os.path.basename(ckpt).split('-')[1])
        total_step = int((FLAGS.num_epochs * FLAGS.num_examples_per_epoch) /
                         FLAGS.train_batch_size)
        if current_step >= total_step:
          tf.logging.info('Evaluation finished after training step %d' %
                          current_step)
          break

      except tf.errors.NotFoundError:
        # Since the coordinator is on a different job than the TPU worker,
        # sometimes the TPU worker does not finish initializing until long after
        # the CPU job tells it to start evaluating. In this case, the checkpoint
        # file could have been deleted already.
        tf.logging.info('Checkpoint %s no longer exists, skipping checkpoint' %
                        ckpt)

  elif FLAGS.mode == 'train_and_eval':
    train_estimator = contrib_tpu.TPUEstimator(
        model_fn=retinanet_segmentation_model.segmentation_model_fn,
        use_tpu=FLAGS.use_tpu,
        train_batch_size=FLAGS.train_batch_size,
        config=run_config,
        params=params)

    eval_estimator = contrib_tpu.TPUEstimator(
        model_fn=retinanet_segmentation_model.segmentation_model_fn,
        use_tpu=FLAGS.use_tpu,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        config=run_config,
        params=eval_params)
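    # Alternate one epoch of training with an evaluation pass over
    # FLAGS.eval_samples examples on each cycle.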
    for cycle in range(FLAGS.num_epochs):
      tf.logging.info('Starting training cycle, epoch: %d.' % cycle)
      train_estimator.train(
          input_fn=dataloader.SegmentationInputReader(
              FLAGS.training_file_pattern, is_training=True),
          steps=int(FLAGS.num_examples_per_epoch / FLAGS.train_batch_size))

      tf.logging.info('Starting evaluation cycle, epoch: {:d}.'.format(
          cycle + 1))
      # Run evaluation after training finishes.
      eval_results = eval_estimator.evaluate(
          input_fn=dataloader.SegmentationInputReader(
              FLAGS.validation_file_pattern, is_training=False),
          steps=FLAGS.eval_samples // FLAGS.eval_batch_size)
      tf.logging.info('Evaluation results: %s' % eval_results)

  else:
    tf.logging.info('Mode %s not found.' % FLAGS.mode)
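The launch boilerplate that hands main() to the flag parser is outside this excerpt. A hedged sketch of the usual TF 1.x entry block, plus an example invocation built only from the flags read above (all values and paths are hypothetical):

if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run(main)

# Example invocation (hypothetical values; flag names are those read in main()):
#
#   python retinanet_segmentation_main.py \
#     --use_tpu=true --tpu=my-tpu --tpu_zone=us-central1-b \
#     --gcp_project=my-project \
#     --mode=train_and_eval \
#     --training_file_pattern='gs://my-bucket/train-*' \
#     --validation_file_pattern='gs://my-bucket/val-*' \
#     --model_dir=gs://my-bucket/model \
#     --resnet_checkpoint=gs://my-bucket/resnet-ckpt \
#     --train_batch_size=64 --eval_batch_size=8 --eval_samples=1500 \
#     --num_examples_per_epoch=10000 --num_epochs=45 \
#     --iterations_per_loop=100 --num_shards=8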