in tensorflow_model_remediation/min_diff/losses/base_loss.py [0:0]
def __call__(self,
             membership: types.TensorType,
             predictions: types.TensorType,
             sample_weight: Optional[types.TensorType] = None):
  """Invokes the `MinDiffLoss` instance.

  Computes the loss with `self.call` and records summaries describing the
  MinDiff batch: example counts, the average prediction, and prediction
  histograms (overall and per membership group).

  Args:
    membership: Labels indicating whether examples are part of the sensitive
      group. Shape must be `[batch_size, d0, .. dN]`.
    predictions: Predicted values. Must be the same shape as membership.
    sample_weight: (Optional) acts as a coefficient for the loss. Must be of
      shape [batch_size] or [batch_size, 1]. If None then a tensor of ones
      with the appropriate shape is used.

  Returns:
    Scalar `min_diff_loss`.
  """
  with tf.name_scope(self.name + '_inputs'):
    # The loss is computed from the caller's tensors exactly as given; the
    # normalization below only affects the summaries.
    loss = self.call(membership, predictions, sample_weight)

    # Calculate metrics. Work in float32 so integer/boolean membership labels
    # or weights can be combined with each other and with the predictions.
    membership_f = tf.cast(membership, tf.float32)
    predictions_f = tf.cast(predictions, tf.float32)
    if sample_weight is None:
      weights = tf.ones_like(membership_f)
    else:
      weights = tf.cast(sample_weight, tf.float32)
      membership_rank = membership_f.shape.rank
      weights_rank = weights.shape.rank
      if membership_rank is not None and weights_rank is not None:
        # Append trailing singleton dims so that e.g. a `[batch_size]` weight
        # lines up per example with `[batch_size, 1]` membership. Without
        # this the two broadcast to `[batch_size, batch_size]` and every
        # count and mask below is wrong.
        for _ in range(membership_rank - weights_rank):
          weights = tf.expand_dims(weights, axis=-1)
        # Expand to membership's full shape so the total count is in the same
        # units (elements) as the sensitive-group count, matching the
        # `sample_weight is None` case.
        weights = weights * tf.ones_like(membership_f)
    num_min_diff_examples = tf.math.count_nonzero(weights)
    num_sensitive_group_min_diff_examples = tf.math.count_nonzero(
        weights * membership_f)
    num_non_sensitive_group_min_diff_examples = (
        num_min_diff_examples - num_sensitive_group_min_diff_examples)
    tf.summary.scalar('sensitive_group_min_diff_examples_count',
                      num_sensitive_group_min_diff_examples)
    tf.summary.scalar('non-sensitive_group_min_diff_examples_count',
                      num_non_sensitive_group_min_diff_examples)
    tf.summary.scalar('min_diff_examples_count', num_min_diff_examples)
    # The following metric can capture when the model degenerates and all
    # predictions go towards zero or one.
    # NOTE(review): the weighted sum is divided by the number of non-zero
    # weights, not by their sum, so this is only a true weighted mean for
    # 0/1 weights -- confirm that is intended.
    tf.summary.scalar(
        'min_diff_average_prediction',
        tf.math.divide_no_nan(
            tf.reduce_sum(weights * predictions_f),
            tf.cast(num_min_diff_examples, dtype=tf.float32)))
    # Plot histogram of the MinDiff predictions.
    summary_histogram = (
        tf.summary.histogram
        if tf.executing_eagerly() else tf.compat.v1.summary.histogram)
    summary_histogram('min_diff_prediction_histogram', predictions)
    # Plot histogram of the MinDiff predictions for each membership class,
    # keeping only examples whose weight is non-zero.
    pos_mask = weights * tf.cast(tf.equal(membership_f, 1.0), tf.float32)
    neg_mask = weights * tf.cast(tf.equal(membership_f, 0.0), tf.float32)
    mask_rank = pos_mask.shape.rank
    # Skip scalars and tensors of unknown rank: there is nothing to index
    # per example.
    if predictions.shape.rank and mask_rank:
      # Selects the first element of each example: `[:, 0]` for rank 2,
      # `[:]` for rank 1 (where `[:, 0]` would raise), `[:, 0, 0]` for rank 3.
      first_of_each_example = (slice(None),) + (0,) * (mask_rank - 1)
      sensitive_group_predictions = tf.boolean_mask(
          predictions, tf.not_equal(pos_mask[first_of_each_example], 0.0))
      non_sensitive_group_predictions = tf.boolean_mask(
          predictions, tf.not_equal(neg_mask[first_of_each_example], 0.0))
      summary_histogram('min_diff_sensitive_group_prediction_histogram',
                        sensitive_group_predictions)
      summary_histogram('min_diff_non-sensitive_group_prediction_histogram',
                        non_sensitive_group_predictions)
  return loss