# reinforcement_learning/rl_network_compression_ray_custom/src/tensorflow_resnet/compressor/resnet.py
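# NOTE: import targets below are assumed; ModeKeys (which adds the custom
# REFERENCE mode), ResNetXXModel, and get_tf_vars_list are defined elsewhere
# in this package.
import tensorflow as tf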
def builder(features, labels, mode, params):
"""Shared functionality for different resnet model_fns.
Initializes the ResnetModel representing the model layers
and uses that model to build the necessary EstimatorSpecs for
the `mode` in question. For training, this means building losses,
the optimizer, and the train op that get passed into the EstimatorSpec.
For evaluation and prediction, the EstimatorSpec is returned without
a train op, but with the necessary parameters for the given mode.
Args:
features: tensor representing input images
labels: tensor representing class labels for all input images
mode: current estimator mode; should be one of
`tf.estimator.ModeKeys.TRAIN`, `EVAL`, or `PREDICT`, or the custom
`ModeKeys.REFERENCE` mode, which returns the bare model instead.
Everything else goes in the params dict:
model_class: a class representing a TensorFlow model that has a __call__
function. We assume here that this is a subclass of ResnetModel.
name: variable scope name of the (student) model.
num_classes: number of output classes.
resnet_size: network depth; each of the three stages gets
(resnet_size - 2) // 6 residual blocks.
data_format: input format ('channels_last', 'channels_first', or None).
If set to None, the format depends on whether a GPU is available.
dtype: dtype of the input features.
loss_scale: loss-scaling factor for low-precision training (1 disables it).
weight_decay: coefficient of the L2 regularization term.
batch_size: training batch size.
num_images: dict of dataset sizes; the 'train' entry drives the
learning-rate schedule.
momentum: momentum for the MomentumOptimizer.
fine_tune: if True, only train the dense (final) layers.
remove_layers: a boolean array indicating which layers to remove.
weights: weights passed through to the model (e.g. pretrained values).
params_scope: parameter scope passed through to the model.
teacher: name of a teacher model for distillation, or None to disable it.
temperature: softmax temperature for the distillation loss.
distillation_coefficient: weight of the distillation loss term.
Returns:
EstimatorSpec parameterized according to the input params and the
current mode. For `ModeKeys.REFERENCE`, the bare model object is
returned instead of an EstimatorSpec.
"""
model_class = params["model_class"]
data_format = params["data_format"]
num_classes = params["num_classes"]
loss_scale = params["loss_scale"]
resnet_size = params["resnet_size"]
fine_tune = params["fine_tune"]
remove_layers = params["remove_layers"]
dtype = params["dtype"]
name = params["name"]
loss_filter_fn = None
num_images = params["num_images"]
batch_size = params["batch_size"]
momentum = params["momentum"]
teacher = params["teacher"]
weights = params["weights"]
params_scope = params["params_scope"]
temperature = params["temperature"]
distillation_coefficient = params["distillation_coefficient"]
weight_decay = params["weight_decay"]
fake = mode == ModeKeys.REFERENCE
if not fake:
# Generate a summary node for the images
tf.summary.image("images", features, max_outputs=6)
# Check that the input features carry the dtype used for computation.
assert features.dtype == dtype
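# CIFAR-style ResNet configuration: a single 3x3 stem convolution, no
# initial pooling, and (resnet_size - 2) // 6 residual blocks in each of
# the three stages (strides 1, 2, 2).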
model = model_class(
name=name,
resnet_size=resnet_size,
bottleneck=False,
num_classes=num_classes,
num_filters=16,
kernel_size=3,
conv_stride=1,
first_pool_size=None,
first_pool_stride=None,
block_sizes=[(resnet_size - 2) // 6] * 3,
block_strides=[1, 2, 2],
data_format=data_format,
)
training = mode == ModeKeys.TRAIN
logits = model(
inputs=features,
training=training,
remove_layers=remove_layers,
weights=weights,
fake=fake,
params_scope=params_scope,
)
if mode == ModeKeys.REFERENCE:
return model
assert not fake
if teacher is not None and mode != ModeKeys.PREDICT:
# Build a teacher model using the same loaded weights.
teacher_model = model_class(
name=teacher,
resnet_size=resnet_size,
bottleneck=False,
num_classes=num_classes,
num_filters=16,
kernel_size=3,
conv_stride=1,
first_pool_size=None,
first_pool_stride=None,
block_sizes=[(resnet_size - 2) // 6] * 3,
block_strides=[1, 2, 2],
data_format=data_format,
)
teacher_logits = teacher_model(
inputs=features,
training=False,
remove_layers=None,
weights=weights,
params_scope=params_scope,
)
# This acts as a no-op if the logits are already in fp32 (provided logits
# are not a SparseTensor). If dtype is low precision, logits must be cast
# to fp32 for numerical stability.
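# (In fp16, exp(x) overflows once x exceeds roughly 11, so a softmax over
# fp16 logits can saturate; the losses below therefore use fp32 logits.)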
logits = tf.cast(logits, tf.float32)
if teacher is not None and mode != ModeKeys.PREDICT:
teacher_logits = tf.cast(teacher_logits, tf.float32)
predictions = {
"classes": tf.argmax(logits, axis=1),
"probabilities": tf.nn.softmax(logits, name="softmax_tensor"),
}
if teacher is not None and mode != ModeKeys.PREDICT:
predictions["teacher_classes"] = tf.argmax(teacher_logits, axis=1)
predictions["teacher_probabilities"] = tf.nn.softmax(
teacher_logits, name="teacher_softmax_tensor"
)
if mode == ModeKeys.PREDICT:
# Return the predictions and the specification for serving a SavedModel
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions,
export_outputs={"predict": tf.estimator.export.PredictOutput(predictions)},
)
# Calculate loss, which includes softmax cross entropy and L2 regularization.
cross_entropy = tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=labels)
if teacher is not None:
distillation_loss = ResNetXXModel.create_distillation_loss(
logits, teacher_logits, temperature
)
teacher_cross_entropy = tf.losses.sparse_softmax_cross_entropy(
logits=teacher_logits, labels=labels
)
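# `create_distillation_loss` is assumed to implement the standard
# Hinton-style soft-target loss; roughly (an assumption, not the verified
# implementation):
#
#   soft_targets = tf.nn.softmax(teacher_logits / temperature)
#   log_probs = tf.nn.log_softmax(logits / temperature)
#   loss = -tf.reduce_mean(
#       tf.reduce_sum(soft_targets * log_probs, axis=1))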
# Create tensors named cross_entropy and distillation_loss for logging
# purposes.
tf.identity(cross_entropy, name="cross_entropy")
tf.summary.scalar("cross_entropy", cross_entropy)
if teacher is not None:
tf.identity(distillation_loss, name="distillation_loss")
tf.summary.scalar("distillation_loss", distillation_coefficient * distillation_loss)
tf.identity(teacher_cross_entropy, name="teacher_cross_entropy")
tf.summary.scalar("teacher_cross_entropy", teacher_cross_entropy)
# If no loss_filter_fn is passed, assume we want the default behavior,
# which is that batch_normalization variables are excluded from loss.
def exclude_batch_norm(name):
return "batch_normalization" not in name
loss_filter_fn = loss_filter_fn or exclude_batch_norm
learning_rate_fn = ResNetXXModel.learning_rate_with_decay(
base_lr=0.1,
batch_size=batch_size,
batch_denom=128,
num_images=num_images["train"],
boundary_epochs=[20, 30],
decay_rates=[0.1, 0.01, 0.001],
)
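# Assuming `learning_rate_with_decay` follows the TF official-models
# convention (initial lr = base_lr * batch_size / batch_denom, multiplied
# by decay_rates[i] inside the i-th boundary interval), a batch size of 128
# gives 0.01 for epochs [0, 20), 0.001 for [20, 30), and 1e-4 afterwards.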
# Add weight decay to the loss.
trainable_vars = get_tf_vars_list(name)
l2_loss = weight_decay * tf.add_n(
# loss is computed using fp32 for numerical stability.
[
tf.nn.l2_loss(tf.cast(v, tf.float32))
for v in trainable_vars
if loss_filter_fn(v.name)
]
)
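# tf.nn.l2_loss(v) computes sum(v ** 2) / 2, so l2_loss equals
# weight_decay * (sum of squared weights) / 2 over the filtered variables.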
tf.summary.scalar("l2_loss", l2_loss)
loss = cross_entropy + l2_loss
if teacher is not None:
loss = loss + distillation_coefficient * distillation_loss
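# Final objective: cross_entropy + weight_decay * L2
# (+ distillation_coefficient * distillation_loss when a teacher is set).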
if mode == ModeKeys.TRAIN:
global_step = tf.train.get_or_create_global_step()
learning_rate = learning_rate_fn(global_step)
# Create a tensor named learning_rate for logging purposes
tf.identity(learning_rate, name="learning_rate")
tf.summary.scalar("learning_rate", learning_rate)
optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=momentum)
def _dense_grad_filter(gvs):
"""Only apply gradient updates to the final layer.
This function is used for fine tuning.
Args:
gvs: list of tuples with gradients and variable info
Returns:
filtered gradients so that only the dense layer remains
"""
return [(g, v) for g, v in gvs if "dense" in v.name]
if loss_scale != 1:
# When computing fp16 gradients, intermediate tensor values are often so
# small that they underflow to 0. To avoid this, we multiply the loss by
# loss_scale to make these tensor values loss_scale times bigger.
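# For example, with loss_scale == 128 a true gradient of 5e-7 scales to
# 6.4e-5, inside fp16's normal range (smallest normal is about 6.1e-5)
# rather than flushing toward zero.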
scaled_grad_vars = optimizer.compute_gradients(loss * loss_scale, var_list=trainable_vars)
if fine_tune:
scaled_grad_vars = _dense_grad_filter(scaled_grad_vars)
# Once the gradient computation is complete we can scale the gradients
# back to the correct scale before passing them to the optimizer.
unscaled_grad_vars = [(grad / loss_scale, var) for grad, var in scaled_grad_vars]
minimize_op = optimizer.apply_gradients(unscaled_grad_vars, global_step)
else:
grad_vars = optimizer.compute_gradients(loss, var_list=trainable_vars)
if fine_tune:
grad_vars = _dense_grad_filter(grad_vars)
minimize_op = optimizer.apply_gradients(grad_vars, global_step)
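# Batch-norm moving-average updates live in the UPDATE_OPS collection and
# must run alongside the minimize op on every training step.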
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
train_op = tf.group(minimize_op, update_ops)
else:
train_op = None
accuracy = tf.metrics.accuracy(labels, predictions["classes"])
accuracy_top_5 = tf.metrics.mean(
tf.nn.in_top_k(predictions=logits, targets=labels, k=5, name="top_5_op")
)
if teacher is not None:
teacher_accuracy = tf.metrics.accuracy(labels, predictions["teacher_classes"])
teacher_accuracy_top_5 = tf.metrics.mean(
tf.nn.in_top_k(
predictions=teacher_logits, targets=labels, k=5, name="teacher_top_5_op"
)
)
metrics = {"accuracy": accuracy, "accuracy_top_5": accuracy_top_5}
if teacher is not None:
metrics["teacher_accuracy"] = teacher_accuracy
metrics["teacher_accuracy_top_5"] = teacher_accuracy_top_5
# Create tensors named train_accuracy and train_accuracy_top_5 for logging purposes
tf.identity(accuracy[1], name="train_accuracy")
tf.identity(accuracy_top_5[1], name="train_accuracy_top_5")
tf.summary.scalar("train_accuracy", accuracy[1])
tf.summary.scalar("train_accuracy_top_5", accuracy_top_5[1])
if teacher is not None:
tf.summary.scalar("teacher_accuracy", teacher_accuracy[1])
tf.summary.scalar("teacher_accuracy_top_5", teacher_accuracy_top_5[1])
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions,
loss=loss,
train_op=train_op,
eval_metric_ops=metrics,
)