in tensorflow_model_remediation/min_diff/keras/models/min_diff_model.py [0:0]
def _conform_weights_to_losses(loss, loss_weight, default_value):
  """Conforms weights to match structure of losses.

  Shapes weights to match the structure of `loss` if possible. If
  `loss_weight` is a single value, it will be broadcast for all losses. If
  `loss_weight` is `None`, or is a dict with missing (or `None`) entries,
  `default_value` will be used for the affected losses.

  Args:
    loss: Loss (possibly nested) that weights will be conformed to.
    loss_weight: Weight that will be conformed to the `loss` structure. If only
      a single value, it will be broadcast for all losses. If `None`, it will
      be replaced by `default_value`. If a dict, any loss key that is missing
      from it (or mapped to `None`) gets `default_value`.
    default_value: Value used if `loss_weight` is `None` or if some weights are
      missing for certain losses.

  Returns:
    Weight corresponding to `loss` structure. When `loss_weight` is a dict, a
    new dict is returned; the caller's dict is never modified.

  Raises:
    ValueError: If `loss_weight` is nested but `loss` is not, if either is
      nested but not a dict, if `loss_weight` contains keys that do not
      correspond to any loss, or if the structures still do not match after
      defaults are filled in.
  """
  # Validate loss (loss_weight is implicitly validated by the checks below).
  structure_utils.validate_min_diff_structure(loss, struct_name="loss")

  # An unnested loss_weight (including None) is broadcast to every loss.
  if not tf.nest.is_nested(loss_weight):
    weight = default_value if loss_weight is None else loss_weight
    return tf.nest.map_structure(lambda _: weight, loss)

  # From here on loss_weight is nested (expected to be a dict).

  # A nested loss_weight can never match an unnested loss. Let tf.nest build a
  # descriptive message and surface it as a ValueError, keeping the cause.
  if not tf.nest.is_nested(loss):
    try:
      tf.nest.assert_same_structure(loss, loss_weight)
    except (ValueError, TypeError) as e:
      raise ValueError("`loss` and `loss_weight` do not have matching "
                       "structures: \n{}".format(e)) from e

  # Defensive check: valid MinDiff structures are either single elements or
  # dicts, so both should be dicts here. This can only trigger if the function
  # is called without validation or if the definition of a valid MinDiff
  # structure changes and this code falls out of sync.
  if not (isinstance(loss, dict) and isinstance(loss_weight, dict)):
    raise ValueError(
        "One of `loss` and `loss_weight` is neither a single element nor a "
        "dict. This should never happen if they are valid MinDiff structures. "
        "If you think this is a valid use case (e.g. if the definition has "
        "changed but this piece of code is out of sync), please file an issue "
        "so we can look at it and make the appropriate fix.")

  # Weights for losses that don't exist are almost certainly a user error.
  if not set(loss_weight) <= set(loss):
    raise ValueError(
        "`loss_weight` contains keys that do not correspond to losses:"
        "\n\nloss: {}\n\nloss_weight: {}".format(loss, loss_weight))

  # Work on a copy so the caller's dict is untouched, then fill in
  # `default_value` for every loss whose weight is missing or explicitly None
  # (matching the documented meaning of a `None` weight).
  conformed = dict(loss_weight)
  for key in loss:
    if conformed.get(key) is None:
      conformed[key] = default_value

  # Defensive check: with defaults filled in, valid MinDiff structures should
  # now match exactly. Surface any mismatch as a ValueError, keeping the cause.
  try:
    tf.nest.assert_same_structure(loss, conformed)
  except (ValueError, TypeError) as e:
    raise ValueError(
        "`loss` and `loss_weight` (potentially with default weights added) "
        "do not have matching structures: \n{}".format(e)) from e
  return conformed