def _compute_single_min_diff_loss()

in tensorflow_model_remediation/min_diff/keras/models/min_diff_model.py [0:0]


  def _compute_single_min_diff_loss(self,
                                    min_diff_data,
                                    loss,
                                    loss_weight,
                                    min_diff_loss_metric,
                                    training=None,
                                    mask=None):
    """Computes one weighted `min_diff_loss` from a loss, weight and data.

    Invoked once per application of MinDiff; see
    `MinDiffModel.compute_min_diff_loss` for how the results are used.

    Args:
      min_diff_data: Packed MinDiff batch. It is split with
        `tf.keras.utils.unpack_x_y_sample_weight` into model inputs,
        membership and sample weights.
      loss: Callable invoked with the `predictions`, `membership` and
        `sample_weight` keyword arguments.
      loss_weight: Multiplier applied to the value returned by `loss`.
      min_diff_loss_metric: Object whose `update_state` receives the weighted
        loss.
      training: Forwarded unchanged to `_call_original_model`.
      mask: Forwarded unchanged to `_call_original_model`.

    Returns:
      `loss_weight * loss(...)` evaluated on the transformed predictions.

    Raises:
      ValueError: If the output of `predictions_transform` is not a
        `tf.Tensor`.
    """
    model_inputs, membership, sample_weight = (
        tf.keras.utils.unpack_x_y_sample_weight(min_diff_data))

    model_output = self._call_original_model(
        model_inputs, training=training, mask=mask)
    # Running the original model on the MinDiff batch may add losses. They are
    # dropped here; the correct losses (if any) get added when the original
    # model is run on the regular inputs.
    self._clear_losses()

    transformed = self.predictions_transform(model_output)

    # Happy path: a Tensor we can compute the MinDiff loss on.
    if isinstance(transformed, tf.Tensor):
      weighted_loss = loss_weight * loss(
          predictions=transformed,
          membership=membership,
          sample_weight=sample_weight)
      min_diff_loss_metric.update_state(weighted_loss)
      return weighted_loss

    # Otherwise build an error message explaining the likely cause, which
    # depends on whether a custom transform was supplied.
    header = ("MinDiff `predictions` meant for calculating the `min_diff_loss` "
              "must be a Tensor, given: {}\n".format(transformed))
    if self._predictions_transform is not None:
      explanation = ("This is due to the fact that the provided "
                     "`predictions_transform` parameter does not return a "
                     "Tensor when given the output of `original_model`.")
    else:
      explanation = (
          "This is due to the fact that `original_model` does not return "
          "a Tensor either because it is multi output or because it has some "
          "custom implementation. To handle this, pass in a "
          "`predictions_transform` that converts the result into the tensor "
          "the `min_diff_loss` should be calculated on.")
    footer = "\nSee `MinDiffModel` for additional documentation."
    raise ValueError(header + explanation + footer)