def _conform_weights_to_losses()

in tensorflow_model_remediation/min_diff/keras/models/min_diff_model.py [0:0]


def _conform_weights_to_losses(loss, loss_weight, default_value):
  """Conforms weights to match the structure of losses.

  Shapes weights to match the structure of `loss` if possible. If
  `loss_weight` is a single (unnested) value, it is broadcast to every loss.
  If `loss_weight` is `None`, or is a dict missing entries for some losses,
  `default_value` is used in their place.

  Args:
    loss: Loss (possibly nested) that weights will be conformed to. It is
      validated as a MinDiff structure before any conforming happens.
    loss_weight: Weight that will be conformed to the `loss` structure. If it
      is a single value, it is broadcast to all losses. If `None`, it is
      replaced by `default_value`. If it is a dict, its keys must be a subset
      of the keys of `loss`.
    default_value: Value used if `loss_weight` is `None` or if weights are
      missing for certain losses.

  Returns:
    Weight(s) with the same structure as `loss`. When `loss_weight` is a dict,
    a new dict is returned and the caller's `loss_weight` is not mutated.

  Raises:
    ValueError: If `loss_weight` is nested but `loss` is not; if either nested
      value is not a dict; if `loss_weight` contains keys that do not
      correspond to losses; or if the final structures do not match.
      Any error raised by `structure_utils.validate_min_diff_structure` for an
      invalid `loss` is propagated unchanged.
  """
  # Validate loss (loss_weight is validated implicitly by the structure checks
  # below).
  structure_utils.validate_min_diff_structure(loss, struct_name="loss")

  # An unnested loss_weight (including None) is broadcast to every loss.
  if not tf.nest.is_nested(loss_weight):
    if loss_weight is None:
      loss_weight = default_value
    return tf.nest.map_structure(lambda _: loss_weight, loss)

  # From here on, loss_weight is nested (a dict for valid MinDiff structures).

  # A nested weight cannot apply to an unnested loss. The assertion below is
  # expected to fail and is used only to produce an informative error message.
  if not tf.nest.is_nested(loss):
    try:
      tf.nest.assert_same_structure(loss, loss_weight)
    except (ValueError, TypeError) as e:
      raise ValueError("`loss` and `loss_weight` do not have matching "
                       "structures: \n{}".format(e)) from e

  # Defensive check: nested valid MinDiff structures are dicts, so this should
  # be unreachable unless this function is called without validation or the
  # definition of a valid MinDiff structure changes. Failing loudly here
  # catches such a broken assumption immediately.
  if not (isinstance(loss, dict) and isinstance(loss_weight, dict)):
    raise ValueError(
        "One of `loss` and `loss_weight` is neither a single element nor a "
        "dict. This should never happen if they are valid MinDiff structures. "
        "If you think this is a valid use case (e.g. if the definition has "
        "changed but this piece of code is out of sync), please file an issue "
        "so we can look at it and make the appropriate fix.")

  # Work on a copy so the caller's dict is never altered.
  loss_weight = loss_weight.copy()

  # Reject weights whose keys do not correspond to any loss.
  if not set(loss_weight) <= set(loss):
    raise ValueError(
        "`loss_weight` contains keys that do not correspond to losses:"
        "\n\nloss: {}\n\nloss_weight: {}".format(loss, loss_weight))

  # Provide defaults for any losses that have no corresponding weight.
  for key in loss:
    loss_weight.setdefault(key, default_value)

  # For valid MinDiff structures the two now match. Assert it anyway so that
  # any broken assumption surfaces as a clear error rather than later misuse.
  try:
    tf.nest.assert_same_structure(loss, loss_weight)
  except (ValueError, TypeError) as e:
    raise ValueError(
        "`loss` and `loss_weight` (potentially with default weights added) "
        "do not have matching structures: \n{}".format(e)) from e

  return loss_weight