in example_zoo/tensorflow/models/keras_cifar_main/official/resnet/keras/keras_cifar_main.py [0:0]
def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats, as produced by
    `keras_common.build_stats` from the `model.fit` history, the
    `model.evaluate` output (None when `flags_obj.skip_eval` is set) and the
    timing callback.
  """
  if flags_obj.enable_eager:
    tf.enable_eager_execution()

  # `get_tf_dtype` yields a TensorFlow dtype object (it is passed as `dtype=`
  # to the synthetic input fn below), not the raw flag spelling. Comparing it
  # with the string 'fp16' never matches because 'fp16' is not a dtype name
  # TensorFlow recognizes, so the guard must compare against `tf.float16`.
  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == tf.float16:
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  # Without an explicit flag, use channels_first when TensorFlow is built with
  # CUDA and channels_last otherwise.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  if flags_obj.use_synthetic_data:
    input_fn = keras_common.get_synth_input_fn(
        height=cifar_main.HEIGHT,
        width=cifar_main.WIDTH,
        num_channels=cifar_main.NUM_CHANNELS,
        num_classes=cifar_main.NUM_CLASSES,
        dtype=dtype)
  else:
    input_fn = cifar_main.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras)

  # NOTE(review): the eval dataset is also built with num_epochs=train_epochs,
  # presumably so it is not exhausted by the per-epoch validation passes in
  # model.fit — confirm against cifar_main.input_fn.
  eval_input_dataset = input_fn(
      is_training=False,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras)

  strategy = distribution_utils.get_distribution_strategy(
      num_gpus=flags_obj.num_gpus,
      turn_off_distribution_strategy=flags_obj.turn_off_distribution_strategy)

  # The optimizer and model are created and compiled inside the strategy
  # scope returned by keras_common.
  strategy_scope = keras_common.get_strategy_scope(strategy)
  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)
    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['categorical_accuracy'])

  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
      learning_rate_schedule, cifar_main.NUM_IMAGES['train'])

  # By default, run one full pass over the training images per epoch. An
  # explicit --train_steps caps the step count and collapses training to a
  # single epoch.
  train_steps = cifar_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  if flags_obj.skip_eval:
    # NOTE(review): pinning the learning phase to 1 (training) presumably
    # avoids building inference-mode branches when no evaluation runs —
    # confirm intent.
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None
  else:
    num_eval_steps = (cifar_main.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)
    validation_data = eval_input_dataset

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=[
                          time_callback,
                          lr_callback,
                          tensorboard_callback
                      ],
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=1)

  stats = keras_common.build_stats(history, eval_output, time_callback)
  return stats