kfac/python/keras/utils.py
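
# Imports assumed by this excerpt; the original file may organize or alias
# them differently (e.g. TF1-style Keras import paths).
import six
import tensorflow as tf
from tensorflow.keras import losses
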
def get_loss_fn(model,
                loss,
                training=None,
                loss_weights=None,
                reduce_fn=tf.reduce_mean,
                name='loss'):
"""Creates a loss function to be used for KFAC's adaptive damping.
This allows Keras KFAC to automatically create the loss function to use
for adaptive_damping. This function would also be useful for a custom training
loop that uses adaptive_damping.
The returned loss function currently does not support masks or sample_weights.
Currently, if you use a categorical crossentropy loss, due to the
implementation of tf.keras.losses.*_crossentropy, it will grab the logits
whether you use a softmax at the end of your model or not. This is true as of
August 1, 2019. Code below:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/backend.py#L4322
  Args:
    model: The tf.keras.Model that will be used with the inputs to the
      returned loss_fn.
    loss: Potentially serialized tf.keras.losses.* loss function(s)/class(es).
      If the model has multiple outputs, this must be a list of losses that
      matches the order of model.outputs, or a dictionary with keys matching
      model.output_names. Each loss must accept the kwargs y_pred and y_true.
      Note that if your model's outputs are logits, you should pass a Keras
      loss callable constructed with from_logits=True. This could be a
      non-Keras loss, but that case is untested.
    training: Boolean indicating whether the loss is used at training or test
      time. This is necessary to set the proper mode for batch norm and
      dropout layers. If None, falls back to the Keras behavior of calling
      the model without passing a value for training.
    loss_weights: If you have multiple losses, a list or dictionary of
      weights for each loss. When a dictionary is passed, losses without a
      weight default to 1.0.
    reduce_fn: The function used to aggregate the loss tensor, tf.reduce_mean
      by default. You may replace this with the identity if your loss
      performs its own reduction. Depending on how you compute your loss in a
      distributed setting, you may want to modify this function (for example,
      if you sum across replicas, the reduce_fn might be
      lambda x: tf.reduce_sum(x) * (1.0 / global_batch_size)).
    name: Name scope for the loss_fn ops.
  Returns:
    A function that takes inputs, and optionally a prediction, and returns a
    loss. This can be used as the KFAC loss_fn for adaptive damping.

  Raises:
    KeyError: If loss is a dictionary that is missing an entry for one of
      model.output_names.
  """
  # Normalize `loss` and `loss_weights`: deserialize string-identified losses
  # and convert dict forms to lists ordered like model.output_names.
  if isinstance(loss, six.string_types):
    loss = losses.deserialize(loss)
  elif isinstance(loss, dict):
    loss = [loss[n] for n in model.output_names]
  if isinstance(loss, list):
    loss = [losses.deserialize(l) if isinstance(l, six.string_types) else l
            for l in loss]
  if isinstance(loss_weights, dict):
    loss_weights = [loss_weights.get(n, 1.0) for n in model.output_names]
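
  # For example (with hypothetical output names 'main' and 'aux'), a
  # two-output model given
  #   loss={'main': 'categorical_crossentropy', 'aux': my_loss_obj}
  #   loss_weights={'main': 0.9}
  # is normalized above into parallel lists ordered like model.output_names,
  # with the missing 'aux' weight defaulting to 1.0.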

  def loss_fn(inputs, prediction=None):
    """Computes the loss for a model given inputs.

    This function is meant to be used with K-FAC's adaptive damping, which is
    why the prediction is optional (K-FAC wants to compute the loss given
    just the inputs).

    Args:
      inputs: A tuple of (model_input(s), label(s)), where both elements are
        tensors or lists/tuples of tensors.
      prediction: The output of the model given the inputs. If this isn't
        provided, the prediction will be computed via
        prediction = model(inputs[0]).

    Returns:
      A tensor with the total reduced loss, including regularization and
      other layer-specific losses.
    """
    with tf.name_scope(name):
      x, y = inputs
      if prediction is None:
        # Recompute the forward pass, honoring the requested training mode.
        if training is not None:
          prediction = model(x, training=training)
        else:
          prediction = model(x)

      if isinstance(prediction, (tuple, list)):
        # Multi-output model: reduce each per-output loss separately, apply
        # the optional loss weights, then sum.
        reduced_losses = [reduce_fn(fn(y_pred=pred_i, y_true=y_i))
                          for fn, pred_i, y_i in zip(loss, prediction, y)]
        if loss_weights:
          reduced_losses = [l * w
                            for l, w in zip(reduced_losses, loss_weights)]
        total_loss = tf.add_n(reduced_losses)
      else:
        total_loss = reduce_fn(loss(y_pred=prediction, y_true=y))

      # Add regularization penalties and other custom layer-specific losses.
      if model.losses:
        total_loss += tf.add_n(model.losses)
      return total_loss

  return loss_fn
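

# --- Usage sketch (not part of the original file) ---
# A minimal, illustrative example of building and calling the returned
# loss_fn for a single-output model. The toy model, shapes, and data below
# are assumptions for demonstration; only get_loss_fn itself comes from this
# file.
def _example_usage():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
      tf.keras.layers.Dense(10),  # emits logits; no final softmax
  ])
  # Because the model emits logits, build the loss with from_logits=True, as
  # the docstring above recommends.
  loss_fn = get_loss_fn(
      model,
      tf.keras.losses.CategoricalCrossentropy(from_logits=True),
      training=True)

  x = tf.random.normal([32, 784])
  y = tf.one_hot(tf.random.uniform([32], maxval=10, dtype=tf.int32), depth=10)

  # K-FAC's adaptive damping can evaluate the loss from the inputs alone...
  loss_from_inputs = loss_fn((x, y))
  # ...or reuse a cached prediction to avoid an extra forward pass.
  prediction = model(x, training=True)
  loss_from_prediction = loss_fn((x, y), prediction=prediction)
  return loss_from_inputs, loss_from_prediction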