in tensorflow_model_remediation/min_diff/keras/models/min_diff_model.py [0:0]
def compute_min_diff_loss(self, min_diff_data, training=None, mask=None):
  # pyformat: disable
  """Computes `min_diff_loss`(es) corresponding to `min_diff_data`.

  Arguments:
    min_diff_data: Tuple of data or valid MinDiff structure of tuples as
      described below.
    training: Boolean indicating whether to run in training or inference mode.
      See `tf.keras.Model.call` for details.
    mask: Mask or list of masks as described in `tf.keras.Model.call`. These
      will be applied when calling the `original_model`.

  `min_diff_data` must either be a single element or have a structure that
  matches the `loss` parameter passed in during initialization. Each element
  of `min_diff_data` (and the matching element of `loss`) corresponds to one
  application of MinDiff.

  As with the input requirements described in `tf.keras.Model.fit`, each
  element of `min_diff_data` must be a tuple of length 2 or 3. The tuple is
  unpacked with the standard `tf.keras.utils.unpack_x_y_sample_weight`
  function:

  ```
  min_diff_data_elem = ...  # Single element from a batch of min_diff_data.
  min_diff_x, min_diff_membership, min_diff_sample_weight = (
      tf.keras.utils.unpack_x_y_sample_weight(min_diff_data_elem))
  ```

  The components are defined as follows:

  - `min_diff_x`: inputs to `original_model` used to get the corresponding
    MinDiff predictions.
  - `min_diff_membership`: numerical [batch_size, 1] `Tensor` indicating which
    group each example comes from (marked as `0.0` or `1.0`).
  - `min_diff_sample_weight`: Optional weight `Tensor`. The weights are
    applied to the examples during the `min_diff_loss` calculation.

  For each application of MinDiff, the `min_diff_loss` is ultimately
  calculated from the MinDiff predictions, which are evaluated as follows:

  ```
  ...  # In compute_min_diff_loss call.

  min_diff_x = ...  # Single batch of MinDiff examples.

  # Get predictions for MinDiff examples.
  min_diff_predictions = self.original_model(min_diff_x, training=training)
  # Transform the predictions if needed. By default this is the identity.
  min_diff_predictions = self.predictions_transform(min_diff_predictions)
  ```

  Returns:
    A scalar if there is only one application of MinDiff. Otherwise, a list of
    `min_diff_loss` values calculated from `min_diff_data`, one per
    application.

  Raises:
    ValueError: If the structure of `min_diff_data` does not match that of the
      `loss` that was passed to the model during initialization.
    ValueError: If the transformed `min_diff_predictions` is not a
      `tf.Tensor`.
  """
  # pyformat: enable
  # Fail fast when the data does not line up with the configured losses.
  structure_utils._assert_same_min_diff_structure(min_diff_data, self._loss)

  flatten = structure_utils._flatten_min_diff_structure
  # Flatten each structure into a parallel sequence, one entry per MinDiff
  # application, and walk them in lockstep.
  per_application = zip(
      flatten(min_diff_data),
      flatten(self._loss),
      flatten(self._loss_weight),
      flatten(self._min_diff_loss_metric),
  )

  losses = []
  for data_elem, loss_fn, loss_weight, loss_metric in per_application:
    losses.append(
        self._compute_single_min_diff_loss(data_elem, loss_fn, loss_weight,
                                           loss_metric, training, mask))

  # A lone application yields its loss directly instead of a 1-element list.
  return losses[0] if len(losses) == 1 else losses