in models/official/mnasnet/mnasnet_main.py [0:0]
def build_model_fn(features, labels, mode, params):
"""The model_fn for MnasNet to be used with TPUEstimator.
Args:
features: `Tensor` of batched images.
labels: `Tensor` of labels for the data samples.
mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`.
params: `dict` of parameters passed to the model from the TPUEstimator;
`params['batch_size']` is always provided and should be used as the
effective batch size.
Returns:
A `TPUEstimatorSpec` for the model.
"""
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
# This is essential if using a Keras-derived model.
tf.keras.backend.set_learning_phase(is_training)
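# The Keras learning phase controls the train/inference behavior of layers
# such as dropout and batch normalization when they are not given an explicit
# `training` argument.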
if isinstance(features, dict):
features = features['feature']
if mode == tf.estimator.ModeKeys.PREDICT:
# Adds an identity node to help TFLite export.
features = tf.identity(features, 'float_image_input')
# In most cases, the NCHW data format should be used instead of NHWC for a
# significant performance boost on GPU. NHWC should be used only if the
# network needs to run on CPU, since the pooling operations are only
# supported in NHWC. On TPU, the XLA compiler figures out the best layout.
if params['data_format'] == 'channels_first':
assert not params['transpose_input'] # channels_first only for GPU
features = tf.transpose(features, [0, 3, 1, 2])
stats_shape = [3, 1, 1]
else:
stats_shape = [1, 1, 3]
if params['transpose_input'] and mode != tf.estimator.ModeKeys.PREDICT:
features = tf.transpose(features, [3, 0, 1, 2]) # HWCN to NHWC
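# When transpose_input is enabled, the input pipeline emits images in HWCN
# order (which typically transfers to the TPU more efficiently), so they are
# transposed back to NHWC here before entering the network.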
# Normalize the image to zero mean and unit variance.
features -= tf.constant(
imagenet_input.MEAN_RGB, shape=stats_shape, dtype=features.dtype)
features /= tf.constant(
imagenet_input.STDDEV_RGB, shape=stats_shape, dtype=features.dtype)
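# MEAN_RGB and STDDEV_RGB are per-channel statistics; stats_shape puts them
# on the channel dimension so they broadcast correctly for either data format.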
has_moving_average_decay = (params['moving_average_decay'] > 0)
tf.logging.info('Using open-source implementation for MnasNet definition.')
override_params = {}
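# Only truthy flag values add an override below; passing 0 (or 0.0) therefore
# falls back to the model's built-in default instead of overriding it with zero.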
if params['batch_norm_momentum']:
override_params['batch_norm_momentum'] = params['batch_norm_momentum']
if params['batch_norm_epsilon']:
override_params['batch_norm_epsilon'] = params['batch_norm_epsilon']
if params['dropout_rate']:
override_params['dropout_rate'] = params['dropout_rate']
if params['data_format']:
override_params['data_format'] = params['data_format']
if params['num_label_classes']:
override_params['num_classes'] = params['num_label_classes']
if params['depth_multiplier']:
override_params['depth_multiplier'] = params['depth_multiplier']
if params['depth_divisor']:
override_params['depth_divisor'] = params['depth_divisor']
if params['min_depth']:
override_params['min_depth'] = params['min_depth']
override_params['use_keras'] = params['use_keras']
def _build_model(model_name):
"""Build the model for a given model name."""
if model_name.startswith('mnasnet'):
return mnasnet_models.build_mnasnet_model(
features,
model_name=model_name,
training=is_training,
override_params=override_params)
elif model_name.startswith('mixnet'):
return mixnet_builder.build_model(
features,
model_name=model_name,
training=is_training,
override_params=override_params)
else:
raise ValueError('Unknown model name {}'.format(model_name))
if params['precision'] == 'bfloat16':
with tf.tpu.bfloat16_scope():
logits, _ = _build_model(params['model_name'])
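# Cast the bfloat16 logits back to float32 so the loss and eval metrics are
# computed in full precision.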
logits = tf.cast(logits, tf.float32)
else: # params['precision'] == 'float32'
logits, _ = _build_model(params['model_name'])
if params['quantized_training']:
try:
from tensorflow.contrib import quantize # pylint: disable=g-import-not-at-top
except ImportError as e:
logging.exception('Quantized training is not supported in TensorFlow 2.x')
raise e
if is_training:
tf.logging.info('Adding fake quantization ops for training.')
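# quant_delay is measured in global steps; fake quantization of weights and
# activations is postponed until the float model has had time to converge.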
quantize.create_training_graph(
quant_delay=int(params['steps_per_epoch'] *
FLAGS.quantization_delay_epochs))
else:
tf.logging.info('Adding fake quantization ops for evaluation.')
quantize.create_eval_graph()
if mode == tf.estimator.ModeKeys.PREDICT:
scaffold_fn = None
if FLAGS.export_moving_average:
# If the model is trained with moving average decay, export it using the
# moving-average variables so the exported model matches the eval metrics.
restore_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
variables_to_restore = get_pretrained_variables_to_restore(
restore_checkpoint, load_moving_average=True)
tf.logging.info('Restoring from the latest checkpoint: %s',
restore_checkpoint)
tf.logging.info(str(variables_to_restore))
def restore_scaffold():
saver = tf.train.Saver(variables_to_restore)
return tf.train.Scaffold(saver=saver)
scaffold_fn = restore_scaffold
predictions = {
'classes': tf.argmax(logits, axis=1),
'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
}
return tf.estimator.tpu.TPUEstimatorSpec(
mode=mode,
predictions=predictions,
export_outputs={
'classify': tf.estimator.export.PredictOutput(predictions)
},
scaffold_fn=scaffold_fn)
# If necessary, in the model_fn, use params['batch_size'] instead of the
# batch-size flags (--train_batch_size or --eval_batch_size).
batch_size = params['batch_size'] # pylint: disable=unused-variable
# Calculate loss, which includes softmax cross entropy and L2 regularization.
one_hot_labels = tf.one_hot(labels, params['num_label_classes'])
cross_entropy = tf.losses.softmax_cross_entropy(
logits=logits,
onehot_labels=one_hot_labels,
label_smoothing=params['label_smoothing'])
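# label_smoothing mixes the one-hot targets with a uniform distribution:
# smoothed = onehot * (1 - label_smoothing) + label_smoothing / num_classes.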
# Add weight decay to the loss for non-batch-normalization variables.
loss = cross_entropy + params['weight_decay'] * tf.add_n([
tf.nn.l2_loss(v)
for v in tf.trainable_variables()
if 'batch_normalization' not in v.name
])
global_step = tf.train.get_global_step()
if has_moving_average_decay:
ema = tf.train.ExponentialMovingAverage(
decay=params['moving_average_decay'], num_updates=global_step)
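# With num_updates set, the effective decay ramps up from a small value early
# in training: min(decay, (1 + num_updates) / (10 + num_updates)).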
ema_vars = mnas_utils.get_ema_vars()
host_call = None
if is_training:
# Compute the current epoch and associated learning rate from global_step.
current_epoch = (
tf.cast(global_step, tf.float32) / params['steps_per_epoch'])
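# Linear scaling rule: the base learning rate is defined for a batch size of
# 256 and scaled proportionally to the global batch size. For example, a base
# rate of 0.016 with a batch size of 4096 would give 0.016 * 4096 / 256 = 0.256.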
scaled_lr = params['base_learning_rate'] * (params['train_batch_size'] / 256.0) # pylint: disable=line-too-long
learning_rate = mnas_utils.build_learning_rate(scaled_lr, global_step,
params['steps_per_epoch'])
optimizer = mnas_utils.build_optimizer(learning_rate)
if params['use_tpu']:
# When using TPU, wrap the optimizer with CrossShardOptimizer which
# handles synchronization details between different TPU cores. To the
# user, this should look like regular synchronous training.
optimizer = tf.tpu.CrossShardOptimizer(optimizer)
if params['add_summaries']:
summary_writer = tf2.summary.create_file_writer(
FLAGS.model_dir, max_queue=params['iterations_per_loop'])
with summary_writer.as_default():
should_record = tf.equal(global_step % params['iterations_per_loop'],
0)
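# Record summaries only on steps that are a multiple of iterations_per_loop,
# i.e. roughly once per TPU host training loop.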
with tf2.summary.record_if(should_record):
tf2.summary.scalar('loss', loss, step=global_step)
tf2.summary.scalar('learning_rate', learning_rate, step=global_step)
tf2.summary.scalar('current_epoch', current_epoch, step=global_step)
# Batch normalization requires UPDATE_OPS to be added as a dependency to
# the train operation.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops + tf.summary.all_v2_summary_ops()):
train_op = optimizer.minimize(loss, global_step)
if has_moving_average_decay:
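# Run the moving-average update after (and conditioned on) the gradient step;
# the EMA update op becomes the returned train_op.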
with tf.control_dependencies([train_op]):
train_op = ema.apply(ema_vars)
else:
train_op = None
eval_metrics = None
if mode == tf.estimator.ModeKeys.EVAL:
def metric_fn(labels, logits):
"""Evaluation metric function.
Evaluates accuracy.
This function is executed on the CPU and should not directly reference
any Tensors in the rest of the `model_fn`. To pass Tensors from the model
to the `metric_fn`, provide them as part of the `eval_metrics`. See
https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec
for more information.
Arguments should match the list of `Tensor` objects passed as the second
element in the tuple passed to `eval_metrics`.
Args:
labels: `Tensor` with shape `[batch]`.
logits: `Tensor` with shape `[batch, num_classes]`.
Returns:
A dict of the metrics to return from evaluation.
"""
predictions = tf.argmax(logits, axis=1)
top_1_accuracy = tf.metrics.accuracy(labels, predictions)
in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
top_5_accuracy = tf.metrics.mean(in_top_5)
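# tf.metrics.* return (value, update_op) pairs, which is the format the
# TPUEstimatorSpec expects for each entry in the returned metrics dict.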
return {
'top_1_accuracy': top_1_accuracy,
'top_5_accuracy': top_5_accuracy,
}
eval_metrics = (metric_fn, [labels, logits])
num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
tf.logging.info('number of trainable parameters: {}'.format(num_params))
# Prepares scaffold_fn if needed.
scaffold_fn = None
if is_training and FLAGS.init_checkpoint:
variables_to_restore = get_pretrained_variables_to_restore(
FLAGS.init_checkpoint, has_moving_average_decay)
tf.logging.info('Initializing from pretrained checkpoint: %s',
FLAGS.init_checkpoint)
if FLAGS.use_tpu:
def init_scaffold():
tf.train.init_from_checkpoint(FLAGS.init_checkpoint,
variables_to_restore)
return tf.train.Scaffold()
scaffold_fn = init_scaffold
else:
tf.train.init_from_checkpoint(FLAGS.init_checkpoint, variables_to_restore)
restore_vars_dict = None
if not is_training and has_moving_average_decay:
# Load moving average variables for eval.
restore_vars_dict = ema.variables_to_restore(ema_vars)
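# ema.variables_to_restore maps moving-average (shadow) variable names back
# to the model variables, so the Saver below loads the averaged weights.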
def eval_scaffold():
saver = tf.train.Saver(restore_vars_dict)
return tf.train.Scaffold(saver=saver)
scaffold_fn = eval_scaffold
return tf.estimator.tpu.TPUEstimatorSpec(
mode=mode,
loss=loss,
train_op=train_op,
host_call=host_call,
eval_metrics=eval_metrics,
scaffold_fn=scaffold_fn)