in tensorflow_model_remediation/min_diff/keras/models/min_diff_model.py [0:0]
def _compute_single_min_diff_loss(self,
                                  min_diff_data,
                                  loss,
                                  loss_weight,
                                  min_diff_loss_metric,
                                  training=None,
                                  mask=None):
    """Computes one weighted `min_diff_loss` for a single MinDiff application.

    This is invoked once per application of MinDiff. See
    `MinDiffModel.compute_min_diff_loss` for details.

    Args:
        min_diff_data: Packed MinDiff data. It is split into
            `(x, membership, sample_weight)` via
            `tf.keras.utils.unpack_x_y_sample_weight`.
        loss: Callable invoked with the keyword arguments `predictions`,
            `membership` and `sample_weight`.
        loss_weight: Factor the result of `loss` is multiplied by.
        min_diff_loss_metric: Object whose `update_state` is called with the
            weighted loss.
        training: Forwarded to the original model call.
        mask: Forwarded to the original model call.

    Returns:
        The result of `loss` multiplied by `loss_weight`.

    Raises:
        ValueError: If the (transformed) predictions are not a `tf.Tensor`.
    """
    features, group_membership, example_weights = (
        tf.keras.utils.unpack_x_y_sample_weight(min_diff_data))

    raw_predictions = self._call_original_model(
        features, training=training, mask=mask)

    # Running the original model on the MinDiff examples may have registered
    # losses. Drop them here: the correct ones, if any, get added when the
    # original model is called on the original inputs.
    self._clear_losses()

    predictions = self.predictions_transform(raw_predictions)

    if not isinstance(predictions, tf.Tensor):
        # Build the message in pieces: what went wrong, the likely cause
        # (depends on whether a custom transform was supplied), and a pointer
        # to the documentation.
        problem = (
            "MinDiff `predictions` meant for calculating the `min_diff_loss` "
            "must be a Tensor, given: {}\n".format(predictions))
        if self._predictions_transform is None:
            cause = (
                "This is due to the fact that `original_model` does not return "
                "a Tensor either because it is multi output or because it has some "
                "custom implementation. To handle this, pass in a "
                "`predictions_transform` that converts the result into the tensor "
                "the `min_diff_loss` should be calculated on.")
        else:
            cause = ("This is due to the fact that the provided "
                     "`predictions_transform` parameter does not return a "
                     "Tensor when given the output of `original_model`.")
        pointer = "\nSee `MinDiffModel` for additional documentation."
        raise ValueError(problem + cause + pointer)

    unweighted_loss = loss(
        predictions=predictions,
        membership=group_membership,
        sample_weight=example_weights)
    weighted_loss = loss_weight * unweighted_loss

    min_diff_loss_metric.update_state(weighted_loss)
    return weighted_loss