def _call_original_model()

in tensorflow_model_remediation/min_diff/keras/models/min_diff_model.py [0:0]


  def _call_original_model(self, inputs, training=None, mask=None):
    """Calls the original model with appropriate args."""

    arg_tuples = [("training", training,
                   self.original_model._expects_training_arg)]

    # Check if the original model call signature uses "mask" and pass mask to
    # the original model if present.
    if "mask" in inspect.getfullargspec((self.original_model.call)).args:
      arg_tuples.append(("mask", mask, self.original_model._expects_mask_arg))
    kwargs = {name: value for name, value, expected in arg_tuples if expected}
    return self.original_model(inputs, **kwargs)