def get_loss_fn()

in kfac/python/keras/utils.py


import six
import tensorflow as tf
from tensorflow.keras import losses


def get_loss_fn(model,
                loss,
                training=None,
                loss_weights=None,
                reduce_fn=tf.reduce_mean,
                name='loss'):
  """Creates a loss function to be used for KFAC's adaptive damping.

  This allows Keras KFAC to automatically create the loss function used for
  adaptive_damping. It is also useful for a custom training loop that uses
  adaptive_damping.

  The returned loss function currently does not support masks or sample_weights.

  Currently, if you use a categorical crossentropy loss, the implementation
  of tf.keras.losses.*_crossentropy will grab the logits whether or not your
  model ends in a softmax. This is true as of August 1, 2019; see:
  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/backend.py#L4322

  Args:
    model: tf.keras.Model model that will be used with the inputs to the
      returned loss_fn.
    loss: Potentially serialized tf.keras.losses.* loss function(s)/class(es).
      If the model has multiple outputs, this must be a list of losses that
      matches the order of model.outputs, or a dictionary with names matching
      output_names. Must accept the kwargs y_pred and y_true. Note that if
      your model outputs logits, you should pass a Keras loss callable or
      class constructed with from_logits=True. A non-Keras loss may also
      work, but this is untested.
    training: Boolean indicating whether the loss is computed at training or
      test time. This is needed to set the proper mode for layers such as
      batch normalization and dropout. If None, this falls back to the Keras
      behavior of calling the model without passing a value for training.
    loss_weights: If you have multiple losses, a list or dictionary of
      weights for each loss. When a dictionary is passed, losses without an
      entry default to a weight of 1.0.
    reduce_fn: The function used to aggregate the loss tensor;
      tf.reduce_mean by default. You may replace this with the identity if
      your loss already performs a reduction. Depending on how you compute
      your loss in a distributed setting, you may want to modify this
      function (for example, if you sum across replicas, the reduce_fn might
      be lambda x: tf.reduce_sum(x) * (1.0 / global_batch_size)).
    name: Name scope for the loss_fn ops.

  Raises:
    KeyError: If loss is a dictionary that is missing an entry for one of
      model.output_names.

  Returns:
    A function that takes inputs and optionally a prediction and will return
    a loss. This can be used as the KFAC loss_fn for adaptive damping.
  """
  if isinstance(loss, six.string_types):
    # Deserialize, e.g., 'mse' -> tf.keras.losses.mean_squared_error.
    loss = losses.deserialize(loss)
  elif isinstance(loss, dict):
    # Order the per-output losses to match model.outputs.
    loss = [loss[n] for n in model.output_names]

  if isinstance(loss, list):
    loss = [losses.deserialize(l) if isinstance(l, six.string_types) else l
            for l in loss]

  if isinstance(loss_weights, dict):
    # Align weights with output order; missing outputs default to 1.0.
    loss_weights = [loss_weights.get(n, 1.0) for n in model.output_names]

  def loss_fn(inputs, prediction=None):
    """Computes loss for a model given inputs.

    This function is meant to be used with K-FAC's adaptive damping, which is
    why the prediction is optional (since K-FAC wants to compute the loss just
    given inputs).

    Args:
      inputs: A tuple with (model_input(s), label(s)), where both elements are
        tensors or lists/tuples of tensors.
      prediction: The output of the model given the inputs. If this isn't
        provided, the prediction will be computed via
        prediction = model(inputs[0]).

    Returns:
      A tensor with the total reduced loss including regularization and other
      layer specific losses.
    """
    with tf.name_scope(name):
      x, y = inputs
      if prediction is None:
        if training is not None:
          prediction = model(x, training=training)
        else:
          prediction = model(x)

      if isinstance(prediction, (tuple, list)):
        reduced_losses = [reduce_fn(fn(y_pred=pred_i, y_true=y_i))
                          for fn, pred_i, y_i in zip(loss, prediction, y)]
        if loss_weights:
          reduced_losses = [l * w for l, w in zip(reduced_losses, loss_weights)]
        total_loss = tf.add_n(reduced_losses)
      else:
        total_loss = reduce_fn(loss(y_pred=prediction, y_true=y))

      # Add regularization penalties and other custom layer-specific losses.
      if model.losses:
        total_loss += tf.add_n(model.losses)

    return total_loss

  return loss_fn
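
Below is a minimal usage sketch of get_loss_fn. Everything other than
get_loss_fn itself (the model architecture, the data, the loss choice, and
the batch size of 32) is an illustrative assumption, not part of the source
file.

# --- Usage sketch (illustrative assumptions; not from the source file) ---
# Assumes TF 2.x. The model and data below are placeholders.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10),  # Emits logits; no softmax layer.
])

# The model outputs logits, so build the loss with from_logits=True.
loss_fn = get_loss_fn(
    model,
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    training=True)

x = tf.random.normal([32, 784])
y = tf.one_hot(tf.random.uniform([32], maxval=10, dtype=tf.int32), depth=10)

# Adaptive damping can call loss_fn with just the inputs...
batch_loss = loss_fn((x, y))
# ...or with a precomputed prediction, avoiding a second forward pass.
batch_loss = loss_fn((x, y), prediction=model(x, training=True))

# For the docstring's distributed example, keep the per-example losses
# unreduced and let reduce_fn do the scaled sum (batch size 32 assumed).
per_example_loss = tf.keras.losses.CategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
sum_loss_fn = get_loss_fn(
    model,
    loss=per_example_loss,
    reduce_fn=lambda t: tf.reduce_sum(t) * (1.0 / 32))

For a multi-output model, pass loss (and optionally loss_weights) as a list
ordered like model.outputs or as a dict keyed by model.output_names, as
described in the docstring.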