def compute_min_diff_loss()

in tensorflow_model_remediation/min_diff/keras/models/min_diff_model.py


  def compute_min_diff_loss(self, min_diff_data, training=None, mask=None):
    # pyformat: disable
    """Computes `min_diff_loss`(es) corresponding to `min_diff_data`.

    Arguments:
      min_diff_data: Tuple of data or valid MinDiff structure of tuples as
        described below.
      training: Boolean indicating whether to run in training or inference mode.
        See `tf.keras.Model.call` for details.
      mask: Mask or list of masks as described in `tf.keras.Model.call`. These
        will be applied when calling the `original_model`.


    `min_diff_data` must have a structure (or be a single element) matching that
    of the `loss` parameter passed in during initialization. Each element of
    `min_diff_data` (and `loss`) corresponds to one application of MinDiff.

    Like the input requirements described in `tf.keras.Model.fit`, each element
    of `min_diff_data` must be a tuple of length 2 or 3. The tuple will be
    unpacked using the standard `tf.keras.utils.unpack_x_y_sample_weight`
    function:

    ```
    min_diff_data_elem = ...  # Single element from a batch of min_diff_data.

    min_diff_x, min_diff_membership, min_diff_sample_weight = (
        tf.keras.utils.unpack_x_y_sample_weight(min_diff_data_elem))
    ```
    The components are defined as follows:

    - `min_diff_x`: inputs to `original_model` to get the corresponding MinDiff
      predictions.
    - `min_diff_membership`: numerical [batch_size, 1] `Tensor` indicating which
      group each example comes from (marked as `0.0` or `1.0`).
    - `min_diff_sample_weight`: Optional weight `Tensor`. The weights will be
      applied to the examples during the `min_diff_loss` calculation.
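
    For example, a single-application batch might be packed as follows (the
    feature names and values here are purely illustrative):

    ```
    min_diff_x = {"feature": tf.constant([[0.1], [0.2], [0.3], [0.4]])}
    min_diff_membership = tf.constant([[0.0], [0.0], [1.0], [1.0]])
    min_diff_sample_weight = tf.constant([[1.0], [1.0], [0.5], [0.5]])

    min_diff_data_elem = tf.keras.utils.pack_x_y_sample_weight(
        min_diff_x, min_diff_membership, min_diff_sample_weight)

    # With multiple applications of MinDiff, `min_diff_data` mirrors the
    # structure of `loss`, e.g. a dict with the same keys:
    # min_diff_data = {"application1": elem_1, "application2": elem_2}
    ```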

    For each application of MinDiff, the `min_diff_loss` is ultimately
    calculated from the MinDiff predictions, which are computed in the
    following way:

    ```
    ...  # In compute_min_diff_loss call.

    min_diff_x = ...  # Single batch of MinDiff examples.

    # Get predictions for MinDiff examples.
    min_diff_predictions = self.original_model(min_diff_x, training=training)
    # Transform the predictions if needed. By default this is the identity.
    min_diff_predictions = self.predictions_transform(min_diff_predictions)
    ```
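
    Conceptually, each application's predictions then feed into its MinDiff
    loss, roughly as sketched below (a simplified sketch; the exact argument
    names and computation, including metric updates, live in
    `_compute_single_min_diff_loss`):

    ```
    # Weighted MinDiff loss for a single application (approximate sketch).
    min_diff_loss = loss_weight * loss(
        predictions=min_diff_predictions,
        membership=min_diff_membership,
        sample_weight=min_diff_sample_weight)
    ```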

    Returns:
      Scalar (if there is only a single application of MinDiff) or list of
        `min_diff_loss` values calculated from `min_diff_data`.

    Raises:
      ValueError: If the structure of `min_diff_data` does not match that of the
        `loss` that was passed to the model during initialization.
      ValueError: If the transformed `min_diff_predictions` is not a
        `tf.Tensor`.
    """
    # pyformat: enable

    structure_utils._assert_same_min_diff_structure(min_diff_data, self._loss)

    # Flatten everything and calculate min_diff_loss for each application.
    flat_data = structure_utils._flatten_min_diff_structure(min_diff_data)
    flat_losses = structure_utils._flatten_min_diff_structure(self._loss)
    flat_weights = structure_utils._flatten_min_diff_structure(
        self._loss_weight)
    flat_metrics = structure_utils._flatten_min_diff_structure(
        self._min_diff_loss_metric)
    min_diff_losses = [
        self._compute_single_min_diff_loss(data, loss, weight, metric, training,
                                           mask) for data, loss, weight, metric
        in zip(flat_data, flat_losses, flat_weights, flat_metrics)
    ]
    # If there is only one application, return a scalar rather than a list.
    if len(min_diff_losses) == 1:
      min_diff_losses = min_diff_losses[0]
    return min_diff_losses
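
A minimal usage sketch (the model wiring and data below are illustrative, and
in normal training `MinDiffModel.call` invokes this method for you):

```
import tensorflow as tf
from tensorflow_model_remediation import min_diff

# Wrap a small hypothetical Keras model with MinDiff using an MMD loss.
original_model = tf.keras.Sequential(
    [tf.keras.layers.Dense(1, activation="sigmoid")])
model = min_diff.keras.MinDiffModel(original_model, min_diff.losses.MMDLoss())

# Hypothetical single-application batch of MinDiff examples.
min_diff_x = tf.constant([[0.1], [0.2], [0.3], [0.4]])
min_diff_membership = tf.constant([[0.0], [0.0], [1.0], [1.0]])

loss_value = model.compute_min_diff_loss(
    (min_diff_x, min_diff_membership), training=False)
```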