def _calculate_mean()

in tensorflow_model_remediation/min_diff/losses/mmd_loss.py [0:0]


  def _calculate_mean(self, predictions_kernel: types.TensorType,
                      normed_weights: types.TensorType,
                      pos_mask: types.TensorType, neg_mask):
    """Computes weighted means of the kernel within and across both groups.

    Pairwise weights are formed as `normed_weights @ normed_weights^T` and
    then restricted to positive/positive, negative/negative and
    positive/negative pairs by the outer products of the masks. Each mean is
    the weighted sum of `predictions_kernel` over the selected pairs divided
    by the sum of the selected weights.

    Args:
      predictions_kernel: Tensor multiplied element-wise with the pairwise
        weights (presumably the pairwise kernel matrix of the predictions --
        TODO confirm against caller).
      normed_weights: Per-example weights (assumed to be a `[batch, 1]` column
        so that the product with its transpose is `[batch, batch]` -- TODO
        confirm).
      pos_mask: Mask selecting the positive group (assumed same layout as
        `normed_weights` -- TODO confirm).
      neg_mask: Mask selecting the negative group (assumed same layout as
        `normed_weights` -- TODO confirm).

    Returns:
      A tuple `(pos_mean, neg_mean, pos_neg_mean)` of scalar tensors. A mean
      whose total weight is zero is reported as 0 (via `divide_no_nan`)
      instead of NaN.
    """
    pairwise_weights = tf.matmul(normed_weights, tf.transpose(normed_weights))
    # Used for both the negative/negative and the positive/negative pairings.
    neg_mask_t = tf.transpose(neg_mask)

    def _masked_mean(pair_mask):
      # Weighted average of the kernel over the pairs selected by `pair_mask`.
      selected_weights = pairwise_weights * pair_mask
      numerator = tf.reduce_sum(selected_weights * predictions_kernel)
      denominator = tf.reduce_sum(selected_weights)
      return tf.math.divide_no_nan(numerator, denominator)

    pos_mean = _masked_mean(tf.matmul(pos_mask, tf.transpose(pos_mask)))
    neg_mean = _masked_mean(tf.matmul(neg_mask, neg_mask_t))
    pos_neg_mean = _masked_mean(tf.matmul(pos_mask, neg_mask_t))

    return pos_mean, neg_mean, pos_neg_mean