def __call__()

in tensorflow_model_remediation/min_diff/losses/base_loss.py [0:0]


  def __call__(self,
               membership: types.TensorType,
               predictions: types.TensorType,
               sample_weight: Optional[types.TensorType] = None):
    """Invokes the `MinDiffLoss` instance and emits diagnostic summaries.

    The loss is computed by `self.call` on the inputs exactly as given. The
    remainder of this method only derives `tf.summary` scalars and histograms
    from the inputs: example counts (overall and per membership group), the
    average prediction, and prediction histograms (overall and per group).

    Args:
      membership: Labels indicating whether examples are part of the sensitive
        group. Shape must be `[batch_size, d0, .. dN]`.
      predictions: Predicted values. Must be the same shape as membership.
      sample_weight: (Optional) acts as a coefficient for the loss. Must be of
        shape [batch_size] or [batch_size, 1].  If None then a tensor of ones
        with the appropriate shape is used.

    Returns:
      Scalar `min_diff_loss`.
    """
    with tf.name_scope(self.name + '_inputs'):
      loss = self.call(membership, predictions, sample_weight)

      # Calculate metrics.
      # The metric arithmetic is done on float32 views of the inputs so that
      # tensors of differing dtypes (e.g. integer membership labels or
      # non-float32 predictions) can be combined. The loss above still
      # receives the raw inputs.
      membership_f32 = tf.cast(membership, tf.float32)
      predictions_f32 = tf.cast(predictions, tf.float32)
      if sample_weight is None:
        weights = tf.ones_like(membership_f32)
      else:
        weights = tf.cast(sample_weight, tf.float32)
        # Both documented `sample_weight` shapes hold one value per example.
        # Lay those values out along the batch axis of `membership` so the
        # products below broadcast per example rather than, e.g., pairing a
        # `[batch_size]` weight with a `[batch_size, 1]` membership into a
        # `[batch_size, batch_size]` matrix.
        weights_shape = weights.shape
        membership_rank = membership_f32.shape.rank
        per_example_weights = (
            weights_shape.rank == 1 or
            (weights_shape.rank == 2 and weights_shape[1] == 1))
        if (membership_rank and per_example_weights and
            weights_shape.rank != membership_rank):
          weights = tf.reshape(weights, [-1] + [1] * (membership_rank - 1))

      # An example counts towards MinDiff only if its weight is non-zero.
      num_min_diff_examples = tf.math.count_nonzero(weights)
      num_sensitive_group_min_diff_examples = tf.math.count_nonzero(
          weights * membership_f32)
      num_non_sensitive_group_min_diff_examples = (
          num_min_diff_examples - num_sensitive_group_min_diff_examples)
      tf.summary.scalar('sensitive_group_min_diff_examples_count',
                        num_sensitive_group_min_diff_examples)
      tf.summary.scalar('non-sensitive_group_min_diff_examples_count',
                        num_non_sensitive_group_min_diff_examples)
      tf.summary.scalar('min_diff_examples_count', num_min_diff_examples)
      # The following metric can capture when the model degenerates and all
      # predictions go towards zero or one.
      # `divide_no_nan` yields 0 when there are no MinDiff examples.
      tf.summary.scalar(
          'min_diff_average_prediction',
          tf.math.divide_no_nan(
              tf.reduce_sum(weights * predictions_f32),
              tf.cast(num_min_diff_examples, dtype=tf.float32)))

      # Plot histogram of the MinDiff predictions.
      summary_histogram = (
          tf.summary.histogram
          if tf.executing_eagerly() else tf.compat.v1.summary.histogram)

      summary_histogram('min_diff_prediction_histogram', predictions)

      # Plot histogram of the MinDiff predictions for each membership class.
      # Pick out only min_diff head training data: the masks are non-zero only
      # for examples that belong to the group and carry a non-zero weight.
      pos_mask = weights * tf.cast(
          tf.equal(membership_f32, 1.0), tf.float32)
      neg_mask = weights * tf.cast(
          tf.equal(membership_f32, 0.0), tf.float32)

      # Skipped when the rank of `predictions` is unknown or zero.
      if predictions.shape.dims:
        if pos_mask.shape.rank == 1:
          # Rank-1 inputs have no second axis to index into.
          pos_rows, neg_rows = pos_mask, neg_mask
        else:
          # Group selection is driven by the first column of the masks.
          pos_rows, neg_rows = pos_mask[:, 0], neg_mask[:, 0]

        sensitive_group_predictions = tf.squeeze(
            tf.gather(predictions, indices=tf.where(pos_rows)))
        non_sensitive_group_predictions = tf.squeeze(
            tf.gather(predictions, indices=tf.where(neg_rows)))

        summary_histogram('min_diff_sensitive_group_prediction_histogram',
                          sensitive_group_predictions)
        summary_histogram('min_diff_non-sensitive_group_prediction_histogram',
                          non_sensitive_group_predictions)

      return loss