in models/official/efficientnet/main.py [0:0]
def model_fn(features, labels, mode, params):
"""The model_fn to be used with TPUEstimator.
Args:
features: `Tensor` of batched images.
    labels: `Tensor` of one-hot labels for the data samples.
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`.
params: `dict` of parameters passed to the model from the TPUEstimator,
`params['batch_size']` is always provided and should be used as the
effective batch size.
Returns:
    A `TPUEstimatorSpec` for the model.
"""
if isinstance(features, dict):
features = features['feature']
  # In most cases, the default data format NCHW instead of NHWC should be
  # used for a significant performance boost on GPU. NHWC should be used
  # only if the network needs to be run on CPU, since the pooling operations
  # are only supported on NHWC. On TPU, the XLA compiler figures out the
  # best layout.
if FLAGS.data_format == 'channels_first':
assert not FLAGS.transpose_input # channels_first only for GPU
features = tf.transpose(features, [0, 3, 1, 2])
stats_shape = [3, 1, 1]
else:
stats_shape = [1, 1, 3]
input_image_size = FLAGS.input_image_size
if not input_image_size:
input_image_size = model_builder_factory.get_model_input_size(
FLAGS.model_name)
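  # With --transpose_input the input pipeline delivers images in HWCN layout
  # (an optimization for TPU input pipelines); recover the shape and transpose
  # back to NHWC before building the model.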
if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT:
features = tf.reshape(features,
[input_image_size, input_image_size, 3, -1])
features = tf.transpose(features, [3, 0, 1, 2]) # HWCN to NHWC
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
has_moving_average_decay = (FLAGS.moving_average_decay > 0)
  # This is essential if using a Keras-derived model.
tf.keras.backend.set_learning_phase(is_training)
logging.info('Using open-source implementation.')
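  # Collect hyperparameters overridden on the command line; they are forwarded
  # to the model builder below.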
override_params = {}
if FLAGS.batch_norm_momentum is not None:
override_params['batch_norm_momentum'] = FLAGS.batch_norm_momentum
if FLAGS.batch_norm_epsilon is not None:
override_params['batch_norm_epsilon'] = FLAGS.batch_norm_epsilon
if FLAGS.dropout_rate is not None:
override_params['dropout_rate'] = FLAGS.dropout_rate
if FLAGS.survival_prob is not None:
override_params['survival_prob'] = FLAGS.survival_prob
if FLAGS.data_format:
override_params['data_format'] = FLAGS.data_format
if FLAGS.num_label_classes:
override_params['num_classes'] = FLAGS.num_label_classes
if FLAGS.depth_coefficient:
override_params['depth_coefficient'] = FLAGS.depth_coefficient
if FLAGS.width_coefficient:
override_params['width_coefficient'] = FLAGS.width_coefficient
if FLAGS.use_bfloat16:
override_params['use_bfloat16'] = FLAGS.use_bfloat16
def normalize_features(features, mean_rgb, stddev_rgb):
"""Normalize the image given the means and stddevs."""
features -= tf.constant(mean_rgb, shape=stats_shape, dtype=features.dtype)
features /= tf.constant(stddev_rgb, shape=stats_shape, dtype=features.dtype)
return features
def build_model():
"""Build model using the model_name given through the command line."""
model_builder = model_builder_factory.get_model_builder(FLAGS.model_name)
normalized_features = normalize_features(features, model_builder.MEAN_RGB,
model_builder.STDDEV_RGB)
logits, _ = model_builder.build_model(
normalized_features,
model_name=FLAGS.model_name,
training=is_training,
override_params=override_params,
model_dir=FLAGS.model_dir)
return logits
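  # When bfloat16 is enabled, build the model under a bfloat16 scope and cast
  # the logits back to float32 so the loss and metrics are computed in full
  # precision.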
if params['use_bfloat16']:
with tf.tpu.bfloat16_scope():
logits = tf.cast(build_model(), tf.float32)
else:
logits = build_model()
if mode == tf.estimator.ModeKeys.PREDICT:
predictions = {
'classes': tf.argmax(logits, axis=1),
'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
}
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions,
export_outputs={
'classify': tf.estimator.export.PredictOutput(predictions)
})
  # If necessary, in the model_fn, use params['batch_size'] instead of the
  # batch size flags (--train_batch_size or --eval_batch_size).
batch_size = params['batch_size'] # pylint: disable=unused-variable
# Calculate loss, which includes softmax cross entropy and L2 regularization.
cross_entropy = tf.losses.softmax_cross_entropy(
logits=logits,
onehot_labels=labels,
label_smoothing=FLAGS.label_smoothing)
# Add weight decay to the loss for non-batch-normalization variables.
loss = cross_entropy + FLAGS.weight_decay * tf.add_n(
[tf.nn.l2_loss(v) for v in tf.trainable_variables()
if 'batch_normalization' not in v.name])
global_step = tf.train.get_global_step()
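  # Optionally track exponential moving averages of the model variables;
  # passing num_updates=global_step ramps the effective decay up from a small
  # value early in training towards FLAGS.moving_average_decay.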
if has_moving_average_decay:
ema = tf.train.ExponentialMovingAverage(
decay=FLAGS.moving_average_decay, num_updates=global_step)
ema_vars = utils.get_ema_vars()
host_call = None
restore_vars_dict = None
if is_training:
# Compute the current epoch and associated learning rate from global_step.
current_epoch = (
tf.cast(global_step, tf.float32) / params['steps_per_epoch'])
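    # Linear scaling rule: scale the base learning rate with the global batch
    # size, using a batch size of 256 as the reference.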
scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)
logging.info('base_learning_rate = %f', FLAGS.base_learning_rate)
learning_rate = utils.build_learning_rate(
scaled_lr,
global_step,
params['steps_per_epoch'],
decay_epochs=FLAGS.lr_decay_epoch,
warmup_epochs=FLAGS.lr_warmup_epochs,
decay_factor=FLAGS.lr_decay_factor,
lr_decay_type=FLAGS.lr_schedule,
total_steps=FLAGS.train_steps)
optimizer = utils.build_optimizer(
learning_rate,
optimizer_name=FLAGS.optimizer,
lars_weight_decay=FLAGS.lars_weight_decay,
lars_epsilon=FLAGS.lars_epsilon)
if FLAGS.use_tpu:
# When using TPU, wrap the optimizer with CrossShardOptimizer which
# handles synchronization details between different TPU cores. To the
# user, this should look like regular synchronous training.
optimizer = tf.tpu.CrossShardOptimizer(optimizer)
# Batch normalization requires UPDATE_OPS to be added as a dependency to
# the train operation.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss, global_step)
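    # Chain the EMA update onto the train op so the shadow variables are
    # refreshed after every training step.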
if has_moving_average_decay:
with tf.control_dependencies([train_op]):
train_op = ema.apply(ema_vars)
if not FLAGS.skip_host_call:
def host_call_fn(gs, lr, ce):
"""Training host call. Creates scalar summaries for training metrics.
        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `host_call_fn`, provide them as part of the `host_call`. See
https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec
for more information.
Arguments should match the list of `Tensor` objects passed as the second
element in the tuple passed to `host_call`.
Args:
          gs: `Tensor` with shape `[batch]` for the global_step.
lr: `Tensor` with shape `[batch]` for the learning_rate.
ce: `Tensor` with shape `[batch]` for the current_epoch.
Returns:
List of summary ops to run on the CPU host.
"""
gs = gs[0]
        # Host call fns are executed FLAGS.iterations_per_loop times after one
        # TPU loop is finished. Setting max_queue to the same value as the
        # number of iterations makes the summary writer flush the data to
        # storage only once per loop.
with tf2.summary.create_file_writer(
FLAGS.model_dir, max_queue=FLAGS.iterations_per_loop).as_default():
with tf2.summary.record_if(True):
tf2.summary.scalar('learning_rate', lr[0], step=gs)
tf2.summary.scalar('current_epoch', ce[0], step=gs)
return tf.summary.all_v2_summary_ops()
      # To log the loss, current learning rate, and epoch for TensorBoard, the
      # summary op needs to be run on the host CPU via host_call. host_call
      # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
      # dimension. These Tensors are implicitly concatenated to
      # [params['batch_size']].
gs_t = tf.reshape(global_step, [1])
lr_t = tf.reshape(learning_rate, [1])
ce_t = tf.reshape(current_epoch, [1])
host_call = (host_call_fn, [gs_t, lr_t, ce_t])
else:
train_op = None
if has_moving_average_decay:
# Load moving average variables for eval.
restore_vars_dict = ema.variables_to_restore(ema_vars)
eval_metrics = None
if mode == tf.estimator.ModeKeys.EVAL:
def metric_fn(labels, logits):
"""Evaluation metric function. Evaluates accuracy.
      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide them as part of the `eval_metrics`. See
https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec
for more information.
Arguments should match the list of `Tensor` objects passed as the second
element in the tuple passed to `eval_metrics`.
Args:
labels: `Tensor` with shape `[batch, num_classes]`.
logits: `Tensor` with shape `[batch, num_classes]`.
Returns:
A dict of the metrics to return from evaluation.
"""
labels = tf.argmax(labels, axis=1)
predictions = tf.argmax(logits, axis=1)
top_1_accuracy = tf.metrics.accuracy(labels, predictions)
in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
top_5_accuracy = tf.metrics.mean(in_top_5)
return {
'top_1_accuracy': top_1_accuracy,
'top_5_accuracy': top_5_accuracy,
}
eval_metrics = (metric_fn, [labels, logits])
num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
logging.info('number of trainable parameters: %d', num_params)
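  # For eval with EMA enabled, restore the moving-average shadow variables in
  # place of the raw trained weights via a custom Saver.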
def _scaffold_fn():
saver = tf.train.Saver(restore_vars_dict)
return tf.train.Scaffold(saver=saver)
if has_moving_average_decay and not is_training:
# Only apply scaffold for eval jobs.
scaffold_fn = _scaffold_fn
else:
scaffold_fn = None
return tf.estimator.tpu.TPUEstimatorSpec(
mode=mode,
loss=loss,
train_op=train_op,
host_call=host_call,
eval_metrics=eval_metrics,
scaffold_fn=scaffold_fn)
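
# A minimal sketch (not part of this file) of how a model_fn like this is
# typically wired into a TPUEstimator; `tpu_cluster_resolver` and
# `train_input_fn` are assumed to be defined elsewhere:
#
#   run_config = tf.estimator.tpu.RunConfig(
#       cluster=tpu_cluster_resolver,
#       model_dir=FLAGS.model_dir,
#       tpu_config=tf.estimator.tpu.TPUConfig(
#           iterations_per_loop=FLAGS.iterations_per_loop))
#   est = tf.estimator.tpu.TPUEstimator(
#       use_tpu=FLAGS.use_tpu,
#       model_fn=model_fn,
#       config=run_config,
#       train_batch_size=FLAGS.train_batch_size,
#       eval_batch_size=FLAGS.eval_batch_size)
#   est.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)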