def run()

in example_zoo/tensorflow/models/keras_imagenet_main/official/resnet/keras/keras_imagenet_main.py [0:0]


def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    The value produced by `keras_common.build_stats` from the Keras training
    history, the final evaluation output (None when `flags_obj.skip_eval` is
    set) and the timing callback.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.
  """
  if flags_obj.enable_eager:
    tf.enable_eager_execution()

  # The value returned here is also handed to the synthetic input function as
  # its `dtype` argument below, so it is expected to be a TensorFlow dtype
  # rather than the raw flag string. Comparing it only against the string
  # 'fp16' would then never match, so check the dtype object as well; the
  # string comparison is kept for backward compatibility.
  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == tf.float16 or dtype == 'fp16':
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  # Fall back to a build-dependent default when no data format was requested.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  if flags_obj.use_synthetic_data:
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype)
  else:
    input_fn = imagenet_main.input_fn

  train_input_dataset = input_fn(is_training=True,
                                 data_dir=flags_obj.data_dir,
                                 batch_size=flags_obj.batch_size,
                                 num_epochs=flags_obj.train_epochs,
                                 parse_record_fn=parse_record_keras)

  eval_input_dataset = input_fn(is_training=False,
                                data_dir=flags_obj.data_dir,
                                batch_size=flags_obj.batch_size,
                                num_epochs=flags_obj.train_epochs,
                                parse_record_fn=parse_record_keras)

  strategy = distribution_utils.get_distribution_strategy(
      num_gpus=flags_obj.num_gpus,
      turn_off_distribution_strategy=flags_obj.turn_off_distribution_strategy)

  strategy_scope = keras_common.get_strategy_scope(strategy)

  # Optimizer and model are created and compiled inside the strategy scope.
  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES)

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['sparse_categorical_accuracy'])

  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  # By default one epoch covers the whole training set in full batches.
  train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  # An explicit step count (capped at one full epoch) replaces the epoch-based
  # schedule with a single epoch of that many steps.
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=[
                          time_callback,
                          lr_callback,
                          tensorboard_callback
                      ],
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      verbose=2)

  # Final evaluation pass over the eval dataset unless evaluation is skipped.
  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=1)
  stats = keras_common.build_stats(history, eval_output, time_callback)
  return stats