def call()

in tensorflow_model_remediation/min_diff/losses/absolute_correlation_loss.py [0:0]


  def call(
      self,
      sensitive_group_labels: types.TensorType,
      y_pred: types.TensorType,
      sample_weight: Optional[types.TensorType] = None) -> types.TensorType:
    """Returns the absolute weighted correlation of labels and predictions.

    The inputs are first passed through `self._preprocess_inputs`, which
    yields the (possibly transformed) labels and predictions together with a
    weight tensor. Weighted means, variances and the covariance are then
    computed as weighted sums using those weights.

    Args:
      sensitive_group_labels: Tensor of sensitive group labels, forwarded to
        `self._preprocess_inputs`.
      y_pred: Tensor of predictions, forwarded to `self._preprocess_inputs`.
      sample_weight: Optional tensor of weights, forwarded to
        `self._preprocess_inputs`. Defaults to `None`.

    Returns:
      The absolute value of the weighted correlation between the preprocessed
      labels and predictions.
    """
    # NOTE(review): the weighted sums below act as means only if the returned
    # weights sum to 1 -- presumably guaranteed by `_preprocess_inputs`;
    # confirm against that helper.
    labels, predictions, weights = self._preprocess_inputs(
        sensitive_group_labels, y_pred, sample_weight)

    # Weighted first moments.
    label_mean = tf.reduce_sum(weights * labels)
    prediction_mean = tf.reduce_sum(weights * predictions)

    # Deviations from the weighted means, shared by the variance and
    # covariance terms.
    label_deviation = labels - label_mean
    prediction_deviation = predictions - prediction_mean

    # Weighted second moments.
    label_variance = tf.reduce_sum(weights * tf.square(label_deviation))
    prediction_variance = tf.reduce_sum(
        weights * tf.square(prediction_deviation))
    covariance = tf.reduce_sum(
        weights * label_deviation * prediction_deviation)

    # _EPSILON is added under each square root to avoid undefined gradients.
    denominator = (
        tf.sqrt(label_variance + _EPSILON) *
        tf.sqrt(prediction_variance + _EPSILON))

    return tf.abs(tf.math.divide_no_nan(covariance, denominator))