in ec2-spot-sagemaker-managed-spot-training/sagemaker-custom-tensorflow/container/cifar10/cifar10.py [0:0]
def model_fn(features, labels, mode):
    """
    Model function for CIFAR-10.

    Builds the network returned by `resnet_model.cifar10_resnet_v2_generator`
    and wraps its outputs in the `EstimatorSpec` that matches `mode`.
    For more information: https://www.tensorflow.org/guide/custom_estimators#write_a_model_function

    Args:
        features: Mapping from which the image tensor is read under the key
            `INPUT_TENSOR_NAME`. The tensor is reshaped to
            `[-1, HEIGHT, WIDTH, DEPTH]` before being fed to the network.
        labels: Class labels; one-hot encoded here with depth `NUM_CLASSES`
            (presumably integer class indices -- confirm against the input
            pipeline). Not read in PREDICT mode.
        mode: A `tf.estimator.ModeKeys` value selecting which spec to build.

    Returns:
        A `tf.estimator.EstimatorSpec`. In PREDICT mode it carries only
        `predictions` and `export_outputs`. Otherwise it carries
        `predictions`, `loss` and `train_op`, where `train_op` is `None`
        unless `mode` is TRAIN.
    """
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    inputs = features[INPUT_TENSOR_NAME]
    tf.summary.image('images', inputs, max_outputs=6)

    network = resnet_model.cifar10_resnet_v2_generator(RESNET_SIZE, NUM_CLASSES)

    inputs = tf.reshape(inputs, [-1, HEIGHT, WIDTH, DEPTH])

    logits = network(inputs, is_training)

    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        export_outputs = {
            SIGNATURE_NAME: tf.estimator.export.PredictOutput(predictions)
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs=export_outputs)

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    # The one-hot depth is NUM_CLASSES (not a literal) so that it always agrees
    # with the number of logits the network above was built to produce.
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits, onehot_labels=tf.one_hot(labels, NUM_CLASSES))

    # Create a tensor named cross_entropy for logging purposes.
    tf.identity(cross_entropy, name='cross_entropy')
    tf.summary.scalar('cross_entropy', cross_entropy)

    # Add weight decay to the loss (L2 penalty over every trainable variable).
    loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
        [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

    train_op = None
    if is_training:
        global_step = tf.train.get_or_create_global_step()

        # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
        boundaries = [int(_BATCHES_PER_EPOCH * epoch) for epoch in [100, 150, 200]]
        values = [_INITIAL_LEARNING_RATE * decay for decay in [1, 0.1, 0.01, 0.001]]
        learning_rate = tf.train.piecewise_constant(
            tf.cast(global_step, tf.int32), boundaries, values)

        # Create a tensor named learning_rate for logging purposes
        tf.identity(learning_rate, name='learning_rate')
        tf.summary.scalar('learning_rate', learning_rate)

        optimizer = tf.train.MomentumOptimizer(
            learning_rate=learning_rate,
            momentum=_MOMENTUM)

        # Batch norm requires update ops to be added as a dependency to the train_op
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op)