# models/experimental/inception/inception_v4.py
def inception_model_fn(features, labels, mode, params):
"""Inception v4 model using Estimator API."""
num_classes = FLAGS.num_classes
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
is_eval = (mode == tf.estimator.ModeKeys.EVAL)
if isinstance(features, dict):
features = features['feature']
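  # Rearrange the input batch into the layout the network expects. The
  # transpose dimensions come from params; this assumes the input pipeline
  # may deliver a transposed layout for faster TPU infeed.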
features = tensor_transform_fn(features, params['model_transpose_dims'])
  # This nested function avoids duplicating the network-building logic for
  # the different values of --precision.
def build_network():
if FLAGS.precision == 'bfloat16':
with contrib_tpu.bfloat16_scope():
logits, end_points = inception.inception_v4(
features,
num_classes,
is_training=is_training)
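      # The network runs in bfloat16 under the scope above; cast the logits
      # back to float32 so the loss and metrics use full precision.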
logits = tf.cast(logits, tf.float32)
elif FLAGS.precision == 'float32':
logits, end_points = inception.inception_v4(
features,
num_classes,
is_training=is_training)
return logits, end_points
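  # With --clear_update_collections, batch-norm statistics are updated in
  # place (updates_collections=None) instead of through UPDATE_OPS, and the
  # arg_scope's weight decay is disabled; L2 regularization is added
  # manually to the loss below.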
if FLAGS.clear_update_collections:
with arg_scope(inception.inception_v4_arg_scope(
weight_decay=0.0,
batch_norm_decay=BATCH_NORM_DECAY,
batch_norm_epsilon=BATCH_NORM_EPSILON,
updates_collections=None)):
logits, end_points = build_network()
else:
with arg_scope(inception.inception_v4_arg_scope(
batch_norm_decay=BATCH_NORM_DECAY,
batch_norm_epsilon=BATCH_NORM_EPSILON)):
logits, end_points = build_network()
predictions = {
'classes': tf.argmax(input=logits, axis=1),
'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
}
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions,
export_outputs={
'classify': tf.estimator.export.PredictOutput(predictions)
})
if mode == tf.estimator.ModeKeys.EVAL and FLAGS.display_tensors and (
not FLAGS.use_tpu):
with tf.control_dependencies([
tf.Print(
predictions['classes'], [predictions['classes']],
summarize=FLAGS.eval_batch_size,
message='prediction: ')
]):
labels = tf.Print(
labels, [labels], summarize=FLAGS.eval_batch_size, message='label: ')
one_hot_labels = tf.one_hot(labels, FLAGS.num_classes, dtype=tf.int32)
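  # If the auxiliary classifier head is present, add its softmax loss with
  # weight 0.4. Both the auxiliary and primary losses use label smoothing
  # of 0.1.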
if 'AuxLogits' in end_points:
tf.losses.softmax_cross_entropy(
onehot_labels=one_hot_labels,
logits=tf.cast(end_points['AuxLogits'], tf.float32),
weights=0.4,
label_smoothing=0.1,
scope='aux_loss')
tf.losses.softmax_cross_entropy(
onehot_labels=one_hot_labels,
logits=logits,
weights=1.0,
label_smoothing=0.1)
losses = tf.add_n(tf.losses.get_losses())
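  # Apply weight decay manually: L2-regularize the convolution/FC 'weights'
  # variables and skip batch-norm parameters.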
l2_loss = []
for v in tf.trainable_variables():
tf.logging.info(v.name)
if 'BatchNorm' not in v.name and 'weights' in v.name:
l2_loss.append(tf.nn.l2_loss(v))
  tf.logging.info('Applying L2 regularization to %d weight tensors.',
                  len(l2_loss))
loss = losses + WEIGHT_DECAY * tf.add_n(l2_loss)
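  # Scale the base learning rate linearly with the global batch size,
  # relative to a reference batch size of 256.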
initial_learning_rate = FLAGS.learning_rate * FLAGS.train_batch_size / 256
  # Pre-divide by the decay that will have accumulated by the end of the
  # cold + warmup phase, so the post-warmup learning rate matches the
  # intended scaled value.
initial_learning_rate /= (
FLAGS.learning_rate_decay**((FLAGS.warmup_epochs + FLAGS.cold_epochs) /
FLAGS.learning_rate_decay_epochs))
final_learning_rate = 0.0001 * initial_learning_rate
host_call = None
train_op = None
if is_training:
batches_per_epoch = _NUM_TRAIN_IMAGES / FLAGS.train_batch_size
global_step = tf.train.get_or_create_global_step()
current_epoch = tf.cast(
(tf.cast(global_step, tf.float32) / batches_per_epoch), tf.int32)
clr = FLAGS.cold_learning_rate
wlr = initial_learning_rate / (FLAGS.warmup_epochs + FLAGS.cold_epochs)
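    # Piecewise schedule: hold the cold learning rate for the first
    # cold_epochs, ramp linearly up to the scaled initial rate during
    # warmup, then apply staircase exponential decay.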
learning_rate = tf.where(
tf.greater_equal(current_epoch, FLAGS.cold_epochs), (tf.where(
tf.greater_equal(current_epoch,
FLAGS.warmup_epochs + FLAGS.cold_epochs),
tf.train.exponential_decay(
learning_rate=initial_learning_rate,
global_step=global_step,
decay_steps=int(FLAGS.learning_rate_decay_epochs *
batches_per_epoch),
decay_rate=FLAGS.learning_rate_decay,
staircase=True), tf.multiply(
tf.cast(current_epoch, tf.float32), wlr))), clr)
# Set a minimum boundary for the learning rate.
learning_rate = tf.maximum(
learning_rate, final_learning_rate, name='learning_rate')
if FLAGS.optimizer == 'sgd':
tf.logging.info('Using SGD optimizer')
optimizer = tf.train.GradientDescentOptimizer(
learning_rate=learning_rate)
elif FLAGS.optimizer == 'momentum':
tf.logging.info('Using Momentum optimizer')
optimizer = tf.train.MomentumOptimizer(
learning_rate=learning_rate, momentum=0.9)
elif FLAGS.optimizer == 'RMS':
tf.logging.info('Using RMS optimizer')
optimizer = tf.train.RMSPropOptimizer(
learning_rate,
RMSPROP_DECAY,
momentum=RMSPROP_MOMENTUM,
epsilon=RMSPROP_EPSILON)
else:
      tf.logging.fatal('Unknown optimizer: %s', FLAGS.optimizer)
if FLAGS.use_tpu:
optimizer = contrib_tpu.CrossShardOptimizer(optimizer)
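    # Run any ops in UPDATE_OPS (e.g. batch-norm moving-average updates when
    # they are not applied in place) before the weight update.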
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss, global_step=global_step)
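    # Optionally maintain an exponential moving average of the trainable and
    # batch-norm moving-average variables, chained after the optimizer step.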
if FLAGS.moving_average:
ema = tf.train.ExponentialMovingAverage(
decay=MOVING_AVERAGE_DECAY, num_updates=global_step)
variables_to_average = (
tf.trainable_variables() + tf.moving_average_variables())
with tf.control_dependencies([train_op]), tf.name_scope('moving_average'):
train_op = ema.apply(variables_to_average)
# To log the loss, current learning rate, and epoch for Tensorboard, the
# summary op needs to be run on the host CPU via host_call. host_call
# expects [batch_size, ...] Tensors, thus reshape to introduce a batch
# dimension. These Tensors are implicitly concatenated to
# [params['batch_size']].
gs_t = tf.reshape(global_step, [1])
loss_t = tf.reshape(loss, [1])
lr_t = tf.reshape(learning_rate, [1])
ce_t = tf.reshape(current_epoch, [1])
if not FLAGS.skip_host_call:
def host_call_fn(gs, loss, lr, ce):
"""Training host call. Creates scalar summaries for training metrics.
This function is executed on the CPU and should not directly reference
any Tensors in the rest of the `model_fn`. To pass Tensors from the
model to the `metric_fn`, provide as part of the `host_call`. See
https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
for more information.
Arguments should match the list of `Tensor` objects passed as the second
element in the tuple passed to `host_call`.
Args:
gs: `Tensor with shape `[batch]` for the global_step
loss: `Tensor` with shape `[batch]` for the training loss.
lr: `Tensor` with shape `[batch]` for the learning_rate.
ce: `Tensor` with shape `[batch]` for the current_epoch.
Returns:
List of summary ops to run on the CPU host.
"""
gs = gs[0]
with summary.create_file_writer(FLAGS.model_dir).as_default():
with summary.always_record_summaries():
summary.scalar('loss', tf.reduce_mean(loss), step=gs)
summary.scalar('learning_rate', tf.reduce_mean(lr), step=gs)
summary.scalar('current_epoch', tf.reduce_mean(ce), step=gs)
return summary.all_summary_ops()
host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])
eval_metrics = None
if is_eval:
def metric_fn(labels, logits):
"""Evaluation metric function. Evaluates accuracy.
This function is executed on the CPU and should not directly reference
any Tensors in the rest of the `model_fn`. To pass Tensors from the model
to the `metric_fn`, provide as part of the `eval_metrics`. See
https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
for more information.
Arguments should match the list of `Tensor` objects passed as the second
element in the tuple passed to `eval_metrics`.
Args:
labels: `Tensor` with shape `[batch, ]`.
logits: `Tensor` with shape `[batch, num_classes]`.
Returns:
A dict of the metrics to return from evaluation.
"""
predictions = tf.argmax(logits, axis=1)
top_1_accuracy = tf.metrics.accuracy(labels, predictions)
in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
top_5_accuracy = tf.metrics.mean(in_top_5)
return {
'accuracy': top_1_accuracy,
'accuracy@5': top_5_accuracy,
}
eval_metrics = (metric_fn, [labels, logits])
return contrib_tpu.TPUEstimatorSpec(
mode=mode,
loss=loss,
train_op=train_op,
host_call=host_call,
eval_metrics=eval_metrics)