in tensorflow_model_remediation/min_diff/keras/models/min_diff_model.py [0:0]
def __init__(self,
             original_model: tf.keras.Model,
             loss,
             loss_weight=1.0,
             predictions_transform=None,
             **kwargs):
  """Initializes a MinDiffModel instance.

  Args:
    original_model: The `tf.keras.Model` being wrapped. It is stored on the
      instance as `_original_model`.
    loss: MinDiff loss, or a structure of losses. The structure is validated
      and each leaf is converted with `loss_utils._get_loss`.
    loss_weight: Weight, or structure of weights, for the MinDiff loss(es).
      It is validated and conformed to the structure of `loss` via
      `_conform_weights_to_losses(..., default_value=1.0)` (presumably
      filling unspecified entries with `1.0`; see that helper).
    predictions_transform: Optional callable stored on the instance. Must be
      callable if it is not `None`. Where it is applied is not visible in
      this method.
    **kwargs: Keyword arguments forwarded to the parent `__init__`.

  Raises:
    ValueError: If `predictions_transform` is passed in but not callable.
    Exception: If initializing the parent class(es) fails, the error is
      re-raised as the same type, with an explanatory message and chained to
      the original error. If that type cannot be built from a single message
      string, the original error is re-raised unchanged.
  """
  # Validate plain arguments first so that a bad value fails fast, before any
  # Keras initialization, loss conversion or metric creation happens.
  if (predictions_transform is not None and
      not callable(predictions_transform)):
    raise ValueError("`predictions_transform` must be callable if passed "
                     "in, given: {}".format(predictions_transform))

  def _raise_with_hint(error, message):
    """Raises `message` as the same type as `error`, chained to `error`."""
    try:
      augmented = type(error)(message)
    except Exception:  # pylint: disable=broad-except
      # Some exception types (e.g. `UnicodeDecodeError`) cannot be built from
      # a single message argument. Rather than masking the real failure with
      # a constructor TypeError, fall back to re-raising the original error.
      augmented = None
    if augmented is None:
      raise error
    raise augmented from error

  # Roundabout way of accessing the Functional class.
  functional_class = tf.keras.Sequential.__bases__[0]
  # We need to handle a special case where a custom MinDiffModel class is
  # created that is also a subclass of the Functional class. In this case, we
  # need to make sure that args match what the Functional.__init__ requires
  # (i.e. `inputs` and `outputs` args) and that the rest of the
  # Functional.__init__ method is skipped (supported by passing in
  # `skip_init=True`).
  # This requires any __init__ methods to not do input validation and to
  # pass through `skip_init`.
  if (isinstance(self, functional_class) and
      not isinstance(self, tf.keras.Sequential)):
    try:
      super(MinDiffModel, self).__init__(
          inputs=None, outputs=None, skip_init=True, **kwargs)
      tf.keras.Model.__init__(self, **kwargs)
    except Exception as e:  # pylint: disable=broad-except
      _raise_with_hint(
          e,
          "There was a problem initializing the MinDiffModel subclass "
          "instance. This was likely caused by:\n"
          " - The kwargs that were passed in were not valid according to "
          "tf.keras.Model or a base of your custom Model.\n"
          " - Some args validation or requirement in your custom Model "
          "__init__ method is too strict.\n"
          " - Your Model subclass is not passing through **kwargs (in "
          "particular `skip_init`) to the super().__init__ invocation.\n"
          "To fix this, either fix the args, loosen the requirements, or "
          "make sure to pass **kwargs to calls with super. If this is not "
          "possible, you may need to integrate MinDiff without using "
          "MinDiffModel.\n"
          "Error raised: {}".format(e))
  else:
    try:
      super(MinDiffModel, self).__init__(**kwargs)
    except Exception as e:  # pylint: disable=broad-except
      _raise_with_hint(
          e,
          "There was a problem initializing the MinDiffModel instance. "
          "This was likely caused by the kwargs that were passed in not "
          "being valid according to tf.keras.Model.\n"
          "Error raised: {}".format(e))

  # Set _auto_track_sub_layers to true to ensure we track the
  # original_model and MinDiff layers.
  self._auto_track_sub_layers = True  # Track sub layers.
  self.built = True  # This Model is built, original_model may or may not be.
  # Masking, if any, is taken care of by original_model.
  self._supports_masking = False
  # Clear input_spec in case there is one. We cannot make any strong
  # assertions because `min_diff_data` may or may not be included and can
  # have different shapes since weight is optional.
  self.input_spec = None

  self._original_model = original_model

  structure_utils.validate_min_diff_structure(loss, struct_name="loss")
  self._loss = tf.nest.map_structure(loss_utils._get_loss, loss)

  structure_utils.validate_min_diff_structure(
      loss_weight, struct_name="loss_weight")
  self._loss_weight = _conform_weights_to_losses(
      self._loss, loss_weight, default_value=1.0)

  self._min_diff_loss_metric = _create_unique_metrics(self._loss,
                                                      self.metrics)

  self._predictions_transform = predictions_transform