def __init__()

in tensorflow_model_remediation/min_diff/keras/models/min_diff_model.py [0:0]


  def __init__(self,
               original_model: tf.keras.Model,
               loss,
               loss_weight=1.0,
               predictions_transform=None,
               **kwargs):
    """Initializes a MinDiffModel instance.

    Args:
      original_model: The `tf.keras.Model` instance being wrapped. It is
        stored as `self._original_model`; masking, if any, is left to it.
      loss: MinDiff loss or structure of losses. The structure is validated
        with `structure_utils.validate_min_diff_structure` and every leaf is
        resolved through `loss_utils._get_loss`.
      loss_weight: Scalar or structure of weights for the MinDiff loss(es).
        Validated like `loss` and then conformed to the structure of the
        resolved losses via `_conform_weights_to_losses` (called with
        `default_value=1.0`).
      predictions_transform: Optional callable, stored as
        `self._predictions_transform`. Presumably applied to predictions
        elsewhere in the class (not shown here) — confirm against `call`.
      **kwargs: Forwarded to the base class `__init__` call(s).

    Raises:
      ValueError: If `predictions_transform` is passed in but not callable.
      Exception: An error raised by a base class `__init__` is re-raised as
        the same exception type with added guidance and chained to the
        original (`__cause__`). If that type cannot be constructed from a
        single message, the original error is re-raised unchanged.
    """

    def _reraise_with_context(error, guidance):
      """Re-raises `error` as its own type with `guidance` prepended."""
      message = guidance + "Error raised: {}".format(error)
      try:
        wrapped = type(error)(message)
      except Exception:  # pylint: disable=broad-except
        # Some exception types need more than one constructor argument.
        # Surface the original error instead of a misleading TypeError
        # coming from the constructor call above.
        wrapped = None
      if wrapped is None:
        raise error
      raise wrapped from error

    # Roundabout way of accessing the Functional class.
    # NOTE: relies on `tf.keras.Sequential` directly subclassing Functional.
    functional_class = tf.keras.Sequential.__bases__[0]
    # We need to handle a special case where a custom MinDiffModel class is
    # created that is also a subclass of the Functional class. In this case, we
    # need to make sure that args match what the Functional.__init__ requires
    # (i.e. `inputs` and `outputs` args) and that the rest of the
    # Functional.__init__ method is skipped (supported by passing in
    # `skip_init=True`).
    # This requires any __init__ methods to not do input validation and to
    # pass through `skip_init`.
    if (isinstance(self, functional_class) and
        not isinstance(self, tf.keras.Sequential)):
      try:
        super(MinDiffModel, self).__init__(
            inputs=None, outputs=None, skip_init=True, **kwargs)
        # The Functional init above was skipped, so initialize the plain
        # `tf.keras.Model` state explicitly.
        tf.keras.Model.__init__(self, **kwargs)
      except Exception as e:  # pylint: disable=broad-except
        _reraise_with_context(
            e,
            "There was a problem initializing the MinDiffModel subclass "
            "instance. This was likely caused by:\n"
            "  - The kwargs that were passed in were not valid according to "
            "tf.keras.Model or a base of your custom Model.\n"
            "  - Some args validation or requirement in your custom Model "
            "__init__ method is too strict.\n"
            "  - Your Model subclass is not passing through **kwargs (in "
            "particular `skip_init`) to the super().__init__ invocation.\n"
            "To fix this, either fix the args, loosen the requirements, or "
            "make sure to pass **kwargs to calls with super. If this is not "
            "possible, you may need to integrate MinDiff without using "
            "MinDiffModel.\n")
    else:
      try:
        super(MinDiffModel, self).__init__(**kwargs)
      except Exception as e:  # pylint: disable=broad-except
        _reraise_with_context(
            e,
            "There was a problem initializing the MinDiffModel instance. "
            "This was likely caused by the kwargs that were passed in not "
            "being valid according to tf.keras.Model.\n")

    # Set _auto_track_sub_layers to true to ensure we track the
    # original_model and MinDiff layers.
    self._auto_track_sub_layers = True  # Track sub layers.
    self.built = True  # This Model is built, original_model may or may not be.
    # Masking, if any, is taken care of by original_model.
    self._supports_masking = False
    # Clear input_spec in case there is one. We cannot make any strong
    # assertions because `min_diff_data` may or may not be included and can
    # have different shapes since weight is optional.
    self.input_spec = None

    self._original_model = original_model
    structure_utils.validate_min_diff_structure(loss, struct_name="loss")
    # Resolve each leaf of the (possibly nested) loss structure.
    self._loss = tf.nest.map_structure(loss_utils._get_loss, loss)
    structure_utils.validate_min_diff_structure(
        loss_weight, struct_name="loss_weight")
    self._loss_weight = _conform_weights_to_losses(
        self._loss, loss_weight, default_value=1.0)
    # Metric(s) for the MinDiff loss. `self.metrics` is passed in, presumably
    # so new metric names do not collide with existing ones.
    self._min_diff_loss_metric = _create_unique_metrics(self._loss,
                                                        self.metrics)

    if (predictions_transform is not None and
        not callable(predictions_transform)):
      raise ValueError("`predictions_transform` must be callable if passed "
                       "in, given: {}".format(predictions_transform))
    self._predictions_transform = predictions_transform