in models/official/resnet/resnet_main.py [0:0]
def resnet_model_fn(features, labels, mode, params):
"""The model_fn for ResNet to be used with TPUEstimator.
Args:
features: `Tensor` of batched images. If transpose_input is enabled, it
is transposed to device layout and reshaped to 1D tensor.
labels: `Tensor` of labels for the data samples
mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
params: `dict` of parameters passed to the model from the TPUEstimator,
`params['batch_size']` is always provided and should be used as the
effective batch size.
Returns:
A `TPUEstimatorSpec` for the model
"""
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
if isinstance(features, dict):
features = features['feature']
# In most cases, the default data format NCHW instead of NHWC should be
# used for a significant performance boost on GPU/TPU. NHWC should be used
# only if the network needs to be run on CPU since the pooling operations
# are only supported on NHWC.
if params['data_format'] == 'channels_first':
assert not params['transpose_input'] # channels_first only for GPU
features = tf.transpose(features, [0, 3, 1, 2])
if params['transpose_input'] and mode != tf.estimator.ModeKeys.PREDICT:
image_size = tf.cast(
    tf.sqrt(tf.shape(features)[0] / (3 * tf.shape(labels)[0])), tf.int32)
features = tf.reshape(features, [image_size, image_size, 3, -1])
features = tf.transpose(features, [3, 0, 1, 2]) # HWCN to NHWC
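# For example (hypothetical sizes): with a per-replica batch of 1024 and
# 224x224x3 images, the flattened input holds 224 * 224 * 3 * 1024 elements,
# so image_size = sqrt(elements / (3 * 1024)) = 224; the reshape recovers the
# HWCN layout and the transpose restores NHWC.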
# Normalize the image to zero mean and unit variance.
features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)
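# MEAN_RGB and STDDEV_RGB are per-channel statistics; the [1, 1, 3] shape
# broadcasts them over the height and width dimensions, which assumes a
# channels-last layout at this point.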
# DropBlock keep_prob for the 4 block groups of ResNet architecture.
# None means applying no DropBlock at the corresponding block group.
dropblock_keep_probs = [None] * 4
if params['dropblock_groups']:
# Scheduled keep_prob for DropBlock.
train_steps = tf.cast(params['train_steps'], tf.float32)
current_step = tf.cast(tf.train.get_global_step(), tf.float32)
current_ratio = current_step / train_steps
dropblock_keep_prob = (1 - current_ratio * (
1 - params['dropblock_keep_prob']))
# Computes DropBlock keep_prob for different block groups of ResNet.
dropblock_groups = [int(x) for x in params['dropblock_groups'].split(',')]
for block_group in dropblock_groups:
if block_group < 1 or block_group > 4:
raise ValueError(
'dropblock_groups should be a comma-separated list of integers '
'between 1 and 4 (dropblock_groups: {}).'
.format(params['dropblock_groups']))
dropblock_keep_probs[block_group - 1] = 1 - (
(1 - dropblock_keep_prob) / 4.0**(4 - block_group))
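# Worked example (assuming dropblock_keep_prob=0.9 and dropblock_groups='3,4'):
# at the end of training the scheduled keep_prob reaches 0.9, so block group 4
# gets keep_prob 0.9 and block group 3 gets 1 - 0.1 / 4 = 0.975; earlier in
# training the keep_probs decay linearly from 1 toward these values.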
has_moving_average_decay = (params['moving_average_decay'] > 0)
if has_moving_average_decay and params['bn_momentum'] > 0:
raise ValueError(
'Should not use both exponential moving average decay and batch norm momentum.')
# This nested function allows us to avoid duplicating the network-building
# logic for different values of --precision.
def build_network():
network = resnet_model.resnet(
resnet_depth=params['resnet_depth'],
num_classes=params['num_label_classes'],
dropblock_size=params['dropblock_size'],
dropblock_keep_probs=dropblock_keep_probs,
pre_activation=params['pre_activation'],
norm_act_layer=params['norm_act_layer'],
data_format=params['data_format'],
se_ratio=params['se_ratio'],
drop_connect_rate=params['drop_connect_rate'],
use_resnetd_stem=params['use_resnetd_stem'],
resnetd_shortcut=params['resnetd_shortcut'],
replace_stem_max_pool=params['replace_stem_max_pool'],
dropout_rate=params['dropout_rate'],
bn_momentum=params['bn_momentum'])
return network(
inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))
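# Build the network once, under a bfloat16 scope when mixed precision is
# requested; only 'bfloat16' and 'float32' are handled below. In the bfloat16
# case the logits are cast back to float32 so that the softmax and loss are
# computed in full precision.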
if params['precision'] == 'bfloat16':
with tf.tpu.bfloat16_scope():
logits = build_network()
logits = tf.cast(logits, tf.float32)
elif params['precision'] == 'float32':
logits = build_network()
if mode == tf.estimator.ModeKeys.PREDICT:
scaffold_fn = None
if FLAGS.export_moving_average:
# If the model is trained with moving average decay, to match evaluation
# metrics, we need to export the model using moving average variables.
restore_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
variables_to_restore = get_pretrained_variables_to_restore(
restore_checkpoint, load_moving_average=True)
tf.logging.info('Restoring from the latest checkpoint: %s',
restore_checkpoint)
tf.logging.info(str(variables_to_restore))
def restore_scaffold():
saver = tf.train.Saver(variables_to_restore)
return tf.train.Scaffold(saver=saver)
scaffold_fn = restore_scaffold
predictions = {
'classes': tf.argmax(logits, axis=1),
'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
}
return tf.estimator.tpu.TPUEstimatorSpec(
mode=mode,
predictions=predictions,
export_outputs={
'classify': tf.estimator.export.PredictOutput(predictions)},
scaffold_fn=scaffold_fn)
# If necessary, in the model_fn, use params['batch_size'] instead of the
# batch size flags (--train_batch_size or --eval_batch_size).
batch_size = params['batch_size'] # pylint: disable=unused-variable
# Calculate loss, which includes softmax cross entropy and L2 regularization.
one_hot_labels = tf.one_hot(labels, params['num_label_classes'])
cross_entropy = tf.losses.softmax_cross_entropy(
logits=logits,
onehot_labels=one_hot_labels,
label_smoothing=params['label_smoothing'])
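# With label smoothing, softmax_cross_entropy smooths the one-hot targets
# toward the uniform distribution: each target becomes
# onehot * (1 - label_smoothing) + label_smoothing / num_classes, which
# discourages over-confident logits.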
# Add weight decay to the loss for all variables except batch normalization
# and EvoNorm parameters. When LARS is enabled, weight decay is expected to
# be handled inside the LARS optimizer rather than added to the loss.
if params['enable_lars']:
loss = cross_entropy
else:
loss = cross_entropy + params['weight_decay'] * tf.add_n([
tf.nn.l2_loss(v)
for v in tf.trainable_variables()
if 'batch_normalization' not in v.name and 'evonorm' not in v.name
])
global_step = tf.train.get_global_step()
if has_moving_average_decay:
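# tf.train.ExponentialMovingAverage with num_updates=global_step caps the
# effective decay at (1 + step) / (10 + step), so the shadow variables track
# the weights more closely early in training.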
ema = tf.train.ExponentialMovingAverage(
decay=params['moving_average_decay'], num_updates=global_step)
ema_vars = get_ema_vars()
host_call = None
if mode == tf.estimator.ModeKeys.TRAIN:
# Compute the current epoch and associated learning rate from global_step.
global_step = tf.train.get_global_step()
steps_per_epoch = params['num_train_images'] / params['train_batch_size']
current_epoch = tf.cast(global_step, tf.float32) / steps_per_epoch
# LARS is a large-batch optimizer. It enables higher accuracy at batch sizes
# of 16K and larger.
if params['enable_lars']:
learning_rate = 0.0
optimizer = lars_util.init_lars_optimizer(current_epoch, params)
else:
learning_rate = learning_rate_schedule(params, current_epoch)
optimizer = tf.train.MomentumOptimizer(
learning_rate=learning_rate,
momentum=params['momentum'],
use_nesterov=True)
if params['use_tpu']:
# When using TPU, wrap the optimizer with CrossShardOptimizer which
# handles synchronization details between different TPU cores. To the
# user, this should look like regular synchronous training.
optimizer = tf.tpu.CrossShardOptimizer(optimizer)
# Batch normalization requires UPDATE_OPS to be added as a dependency to
# the train operation.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss, global_step)
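# When moving average decay is enabled, chain the EMA update after the
# gradient step so the shadow variables are computed from the freshly
# updated weights.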
if has_moving_average_decay:
with tf.control_dependencies([train_op]):
train_op = ema.apply(ema_vars)
if not params['skip_host_call']:
def host_call_fn(gs, loss, lr, ce):
"""Training host call. Creates scalar summaries for training metrics.
This function is executed on the CPU and should not directly reference
any Tensors in the rest of the `model_fn`. To pass Tensors from the
model to this function, provide them as part of the `host_call`. See
https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec
for more information.
Arguments should match the list of `Tensor` objects passed as the second
element in the tuple passed to `host_call`.
Args:
gs: `Tensor` with shape `[batch]` for the global_step.
loss: `Tensor` with shape `[batch]` for the training loss.
lr: `Tensor` with shape `[batch]` for the learning_rate.
ce: `Tensor` with shape `[batch]` for the current_epoch.
Returns:
List of summary ops to run on the CPU host.
"""
gs = gs[0]
# Host call fns are executed params['iterations_per_loop'] times after
# one TPU loop is finished. Setting the max_queue value to the number of
# iterations makes the summary writer flush the data to storage only once
# per loop.
with tf2.summary.create_file_writer(
FLAGS.model_dir,
max_queue=params['iterations_per_loop']).as_default():
with tf2.summary.record_if(True):
tf2.summary.scalar('loss', loss[0], step=gs)
tf2.summary.scalar('learning_rate', lr[0], step=gs)
tf2.summary.scalar('current_epoch', ce[0], step=gs)
return tf.summary.all_v2_summary_ops()
# To log the loss, current learning rate, and epoch for TensorBoard, the
# summary op needs to be run on the host CPU via host_call. host_call
# expects [batch_size, ...] Tensors, so reshape these scalars to introduce
# a batch dimension. They are implicitly concatenated to
# [params['batch_size']].
gs_t = tf.reshape(global_step, [1])
loss_t = tf.reshape(loss, [1])
lr_t = tf.reshape(learning_rate, [1])
ce_t = tf.reshape(current_epoch, [1])
host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])
else:
train_op = None
eval_metrics = None
if mode == tf.estimator.ModeKeys.EVAL:
def metric_fn(labels, logits):
"""Evaluation metric function. Evaluates accuracy.
This function is executed on the CPU and should not directly reference
any Tensors in the rest of the `model_fn`. To pass Tensors from the model
to the `metric_fn`, provide them as part of the `eval_metrics`. See
https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec
for more information.
Arguments should match the list of `Tensor` objects passed as the second
element in the tuple passed to `eval_metrics`.
Args:
labels: `Tensor` with shape `[batch]`.
logits: `Tensor` with shape `[batch, num_classes]`.
Returns:
A dict of the metrics to return from evaluation.
"""
predictions = tf.argmax(logits, axis=1)
top_1_accuracy = tf.metrics.accuracy(labels, predictions)
in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
top_5_accuracy = tf.metrics.mean(in_top_5)
return {
'top_1_accuracy': top_1_accuracy,
'top_5_accuracy': top_5_accuracy,
}
eval_metrics = (metric_fn, [labels, logits])
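# metric_fn runs on the host CPU; tf.metrics.accuracy and tf.metrics.mean
# return (value, update_op) pairs, and TPUEstimator accumulates the update
# ops over all evaluation batches before reporting the final values.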
# Prepares scaffold_fn if needed.
scaffold_fn = None
restore_vars_dict = None
if not is_training and has_moving_average_decay:
# Load moving average variables for eval.
restore_vars_dict = ema.variables_to_restore(ema_vars)
def eval_scaffold():
saver = tf.train.Saver(restore_vars_dict)
return tf.train.Scaffold(saver=saver)
scaffold_fn = eval_scaffold
return tf.estimator.tpu.TPUEstimatorSpec(
mode=mode,
loss=loss,
train_op=train_op,
host_call=host_call,
eval_metrics=eval_metrics,
scaffold_fn=scaffold_fn)