in tensorflow_model_remediation/min_diff/keras/models/min_diff_model.py [0:0]
def _call_original_model(self, inputs, training=None, mask=None):
  """Calls the original model with appropriate args.

  `training` and `mask` are forwarded as keyword arguments only when the
  original model can accept them, so models with narrower `call` signatures
  are not broken.

  Args:
    inputs: Inputs passed positionally to `self.original_model`.
    training: Forwarded as the `training` kwarg when
      `self.original_model._expects_training_arg` is truthy.
    mask: Forwarded as the `mask` kwarg when `mask` is an explicitly named
      parameter of `self.original_model.call` (positional-or-keyword or
      keyword-only) and `self.original_model._expects_mask_arg` is truthy.

  Returns:
    The result of `self.original_model(inputs, **kwargs)`.
  """
  original_model = self.original_model
  arg_tuples = [("training", training, original_model._expects_training_arg)]
  # Only pass `mask` when the original model's `call` names it explicitly.
  # Keyword-only parameters (e.g. `def call(self, inputs, *, mask=None)`) live
  # in `kwonlyargs`, not `args`, so both lists must be checked or the mask
  # would be silently dropped for such models.
  argspec = inspect.getfullargspec(original_model.call)
  if "mask" in argspec.args or "mask" in argspec.kwonlyargs:
    arg_tuples.append(("mask", mask, original_model._expects_mask_arg))
  kwargs = {name: value for name, value, expected in arg_tuples if expected}
  return original_model(inputs, **kwargs)