in example_zoo/tensorflow/models/keras_imagenet_main/official/resnet/resnet_run_loop.py [0:0]
def resnet_model_fn(features, labels, mode, model_class,
                    resnet_size, weight_decay, learning_rate_fn, momentum,
                    data_format, resnet_version, loss_scale,
                    loss_filter_fn=None, dtype=resnet_model.DEFAULT_DTYPE,
                    fine_tune=False):
  """Shared functionality for different resnet model_fns.

  Initializes the ResnetModel representing the model layers
  and uses that model to build the necessary EstimatorSpecs for
  the `mode` in question. For training, this means building losses,
  the optimizer, and the train op that get passed into the EstimatorSpec.
  For evaluation and prediction, the EstimatorSpec is returned without
  a train op, but with the necessary parameters for the given mode.

  Args:
    features: tensor representing input images
    labels: tensor representing class labels for all input images
    mode: current estimator mode; should be one of
      `tf.estimator.ModeKeys.TRAIN`, `EVALUATE`, `PREDICT`
    model_class: a class representing a TensorFlow model that has a __call__
      function. We assume here that this is a subclass of ResnetModel.
    resnet_size: A single integer for the size of the ResNet model.
    weight_decay: weight decay loss rate used to regularize learned variables.
    learning_rate_fn: function that returns the current learning rate given
      the current global_step
    momentum: momentum term used for optimization
    data_format: Input format ('channels_last', 'channels_first', or None).
      If set to None, the format is dependent on whether a GPU is available.
    resnet_version: Integer representing which version of the ResNet network
      to use. See README for details. Valid values: [1, 2]
    loss_scale: The factor to scale the loss for numerical stability. A
      detailed summary is present in the arg parser help text.
    loss_filter_fn: function that takes a string variable name and returns
      True if the var should be included in loss calculation, and False
      otherwise. If None, batch_normalization variables will be excluded
      from the loss.
    dtype: the TensorFlow dtype to use for calculations.
    fine_tune: If True only train the dense layers (final layers).

  Returns:
    EstimatorSpec parameterized according to the input params and the
    current mode.

  Raises:
    ValueError: if `features.dtype` does not match `dtype`.
  """

  # Generate a summary node for the images
  tf.summary.image('images', features, max_outputs=6)

  # Check that features/images have the same data type being used for
  # calculations. An explicit raise (instead of `assert`) keeps the check
  # active when Python runs with -O, which strips assert statements.
  if features.dtype != dtype:
    raise ValueError(
        'features.dtype ({}) does not match the calculation dtype ({}).'.format(
            features.dtype, dtype))

  model = model_class(resnet_size, data_format, resnet_version=resnet_version,
                      dtype=dtype)

  logits = model(features, mode == tf.estimator.ModeKeys.TRAIN)

  # This acts as a no-op if the logits are already in fp32 (provided logits are
  # not a SparseTensor). If dtype is low precision, logits must be cast to
  # fp32 for numerical stability.
  logits = tf.cast(logits, tf.float32)

  predictions = {
      'classes': tf.argmax(logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    # Return the predictions and the specification for serving a SavedModel
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'predict': tf.estimator.export.PredictOutput(predictions)
        })

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  cross_entropy = tf.losses.sparse_softmax_cross_entropy(
      logits=logits, labels=labels)

  # Create a tensor named cross_entropy for logging purposes.
  tf.identity(cross_entropy, name='cross_entropy')
  tf.summary.scalar('cross_entropy', cross_entropy)

  # If no loss_filter_fn is passed, assume we want the default behavior,
  # which is that batch_normalization variables are excluded from loss.
  def exclude_batch_norm(name):
    return 'batch_normalization' not in name
  loss_filter_fn = loss_filter_fn or exclude_batch_norm

  # Add weight decay to the loss. Each L2 term is computed in fp32 for
  # numerical stability.
  l2_terms = [tf.nn.l2_loss(tf.cast(v, tf.float32))
              for v in tf.trainable_variables() if loss_filter_fn(v.name)]
  if l2_terms:
    l2_loss = weight_decay * tf.add_n(l2_terms)
  else:
    # tf.add_n rejects an empty list; if the filter excludes every trainable
    # variable there is simply no weight-decay contribution.
    l2_loss = tf.constant(0.0, dtype=tf.float32)
  tf.summary.scalar('l2_loss', l2_loss)
  loss = cross_entropy + l2_loss

  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.train.get_or_create_global_step()

    learning_rate = learning_rate_fn(global_step)

    # Create a tensor named learning_rate for logging purposes
    tf.identity(learning_rate, name='learning_rate')
    tf.summary.scalar('learning_rate', learning_rate)

    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate,
        momentum=momentum
    )

    def _dense_grad_filter(gvs):
      """Only apply gradient updates to the final layer.

      This function is used for fine tuning.

      Args:
        gvs: list of (gradient, variable) tuples.

      Returns:
        The subset of `gvs` whose variable name contains 'dense'.
      """
      return [(g, v) for g, v in gvs if 'dense' in v.name]

    if loss_scale != 1:
      # When computing fp16 gradients, often intermediate tensor values are
      # so small, they underflow to 0. To avoid this, we multiply the loss by
      # loss_scale to make these tensor values loss_scale times bigger.
      scaled_grad_vars = optimizer.compute_gradients(loss * loss_scale)

      if fine_tune:
        scaled_grad_vars = _dense_grad_filter(scaled_grad_vars)

      # Once the gradient computation is complete we can scale the gradients
      # back to the correct scale before passing them to the optimizer.
      # compute_gradients yields None for variables the loss does not depend
      # on; pass those through unchanged (matching the unscaled branch below)
      # instead of failing on `None / loss_scale`.
      unscaled_grad_vars = [
          (grad / loss_scale if grad is not None else None, var)
          for grad, var in scaled_grad_vars]
      minimize_op = optimizer.apply_gradients(unscaled_grad_vars, global_step)
    else:
      grad_vars = optimizer.compute_gradients(loss)
      if fine_tune:
        grad_vars = _dense_grad_filter(grad_vars)
      minimize_op = optimizer.apply_gradients(grad_vars, global_step)

    # Group the optimizer step with every op registered in the UPDATE_OPS
    # collection so those ops also run on each training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    train_op = tf.group(minimize_op, update_ops)
  else:
    train_op = None

  accuracy = tf.metrics.accuracy(labels, predictions['classes'])
  accuracy_top_5 = tf.metrics.mean(tf.nn.in_top_k(predictions=logits,
                                                  targets=labels,
                                                  k=5,
                                                  name='top_5_op'))
  metrics = {'accuracy': accuracy,
             'accuracy_top_5': accuracy_top_5}

  # Create tensors named train_accuracy / train_accuracy_top_5 for logging
  # purposes. Index [1] selects the second element of each tf.metrics
  # (value, update_op) pair.
  tf.identity(accuracy[1], name='train_accuracy')
  tf.identity(accuracy_top_5[1], name='train_accuracy_top_5')
  tf.summary.scalar('train_accuracy', accuracy[1])
  tf.summary.scalar('train_accuracy_top_5', accuracy_top_5[1])

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=metrics)