in tensorflow_model_remediation/min_diff/losses/absolute_correlation_loss.py [0:0]
def call(
    self,
    sensitive_group_labels: types.TensorType,
    y_pred: types.TensorType,
    sample_weight: Optional[types.TensorType] = None) -> types.TensorType:
  """Returns the absolute value of the weighted correlation of the inputs.

  The arguments are first passed through `self._preprocess_inputs`, which
  hands back the labels, the predictions and a weight tensor. Weight-multiplied
  sums then give the means, the variances and the covariance, and the result
  is `|covar / (sqrt(var_labels + _EPSILON) * sqrt(var_pred + _EPSILON))|`.

  Args:
    sensitive_group_labels: Tensor of sensitive group labels.
    y_pred: Tensor of predictions.
    sample_weight: Optional tensor forwarded to `_preprocess_inputs`.

  Returns:
    A tensor holding the absolute value of the computed correlation.
  """
  labels, preds, weights = self._preprocess_inputs(
      sensitive_group_labels, y_pred, sample_weight)

  # NOTE(review): these sums are weighted averages only if `weights` sums to
  # one; presumably `_preprocess_inputs` normalizes them -- confirm there.
  labels_mean = tf.reduce_sum(weights * labels)
  preds_mean = tf.reduce_sum(weights * preds)

  # Deviations from the weighted means, shared by variance and covariance.
  labels_centered = labels - labels_mean
  preds_centered = preds - preds_mean

  labels_var = tf.reduce_sum(weights * tf.square(labels_centered))
  preds_var = tf.reduce_sum(weights * tf.square(preds_centered))
  covar = tf.reduce_sum(weights * labels_centered * preds_centered)

  # _EPSILON is added under each square root to avoid non-defined gradients.
  denominator = tf.sqrt(labels_var + _EPSILON) * tf.sqrt(preds_var + _EPSILON)
  return tf.abs(tf.math.divide_no_nan(covar, denominator))