in example_zoo/tensorflow/models/keras_imagenet_main/official/resnet/keras/keras_imagenet_main.py [0:0]
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    The result of `keras_common.build_stats` applied to the `fit` history,
    the final `evaluate` output (None when evaluation is skipped) and the
    timing callback.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.
  """
  if flags_obj.enable_eager:
    tf.enable_eager_execution()

  # `get_tf_dtype` yields a TF dtype (it is passed as `dtype=` to the synthetic
  # input fn below), so compare it with `tf.float16`. Comparing it with the
  # flag string 'fp16' never matches and would let fp16 runs through silently.
  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == tf.float16:
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm). The phase is fixed here,
    # before `resnet50()` builds the model, so the layers are constructed with
    # it; setting it after the model is built does not rebuild those layers.
    tf.keras.backend.set_learning_phase(1)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype)
  else:
    input_fn = imagenet_main.input_fn

  train_input_dataset = input_fn(is_training=True,
                                 data_dir=flags_obj.data_dir,
                                 batch_size=flags_obj.batch_size,
                                 num_epochs=flags_obj.train_epochs,
                                 parse_record_fn=parse_record_keras)

  strategy = distribution_utils.get_distribution_strategy(
      num_gpus=flags_obj.num_gpus,
      turn_off_distribution_strategy=flags_obj.turn_off_distribution_strategy)

  strategy_scope = keras_common.get_strategy_scope(strategy)

  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES)

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['sparse_categorical_accuracy'])

  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  if flags_obj.train_steps:
    # An explicit step budget caps training to a single epoch of at most
    # that many steps.
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  if flags_obj.skip_eval:
    # No validation during `fit` and no final `evaluate`, so the eval
    # dataset is never built.
    num_eval_steps = None
    validation_data = None
  else:
    num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)
    # NOTE(review): the eval dataset is repeated `train_epochs` times,
    # presumably so per-epoch validation inside `fit` does not exhaust it —
    # confirm against `input_fn`.
    validation_data = input_fn(is_training=False,
                               data_dir=flags_obj.data_dir,
                               batch_size=flags_obj.batch_size,
                               num_epochs=flags_obj.train_epochs,
                               parse_record_fn=parse_record_keras)

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=[
                          time_callback,
                          lr_callback,
                          tensorboard_callback
                      ],
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(validation_data,
                                 steps=num_eval_steps,
                                 verbose=1)

  stats = keras_common.build_stats(history, eval_output, time_callback)
  return stats