def _calculate_mutual_information_for_feature_value()

in tensorflow_transform/beam/analyzer_impls.py


def _calculate_mutual_information_for_feature_value(feature_and_accumulator,
                                                    global_accumulator,
                                                    use_adjusted_mutual_info,
                                                    min_diff_from_avg):
  """Calculates the (possibly adjusted) mutual information of a feature value.

  Used as a measure of relatedness between a single feature value and a label.

  Mutual information is calculated as:
  H(x, y) = (sum(weights) *
             [(P(y|x)*log2(P(y|x)/P(y))) + (P(~y|x)*log2(P(~y|x)/P(~y)))])
  where x is the feature and y is the label. We use sum(weights) instead of
  P(x), as this makes the mutual information more interpretable: because the
  result is scaled by the feature's weighted count rather than normalized to
  a probability, it can be read as an adjusted weighted count.
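
  For illustration, with hypothetical values sum(weights) = 10, P(y|x) = 0.8,
  and P(y) = 0.5:
  H(x, y) = 10 * [0.8*log2(0.8/0.5) + 0.2*log2(0.2/0.5)], approximately 2.78.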

  If use_adjusted_mutual_info is True, we use Adjusted Mutual Information (AMI)
  which accounts for relatedness due to chance. AMI is generally calculated as:
  AMI(x, y) = (MI(x, y) - EMI(x, y)) / (max(H(x), H(y)) - EMI(x, y))
  where x is the feature and y is the label. Here, we leave off the
  normalization and only subtract expected mutual information (EMI) from
  mutual information.
  The calculation is based on the following paper:

  Vinh, N. X.; Epps, J.; Bailey, J. (2009). "Information theoretic measures
  for clusterings comparison". Proceedings of the 26th Annual International
  Conference on Machine Learning - ICML '09. p. 1.
  doi:10.1145/1553374.1553511. ISBN 9781605585161.

  A short summary can be found in the Wikipedia article:
  https://en.wikipedia.org/wiki/Adjusted_mutual_information
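
  For example (hypothetical values): if MI(x, y) = 2.78 and EMI(x, y) = 0.12,
  the adjusted value reported here is 2.78 - 0.12 = 2.66; the normalization
  by (max(H(x), H(y)) - EMI(x, y)) is omitted.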

  Args:
    feature_and_accumulator: A tuple of the form (feature,
      WeightedMeanAndVarCombiner.accumulator_class), where `feature` is the
      single token in the vocabulary for which the (possibly adjusted) mutual
      information with the label is being computed, `mean` is the weighted
      mean of the positive label for each label value given x, `count` is
      the count of weights for the feature, and `weight` is the mean of the
      weights for the feature.
    global_accumulator: A WeightedMeanAndVarCombiner.accumulator_class, where
      `mean` is the weighted mean of the positive label for each label value
      across all features, `count` is the count for all features, and
      `weight` is the mean of the weights for all features.
    use_adjusted_mutual_info: If set to True, use adjusted mutual information.
    min_diff_from_avg: A regularization parameter that pushes low MI/AMI
      towards zero. The mutual information of a feature-label pair is set to
      zero whenever the absolute difference between the observed weight and
      the expected (average) weight is lower than min_diff_from_avg.

  Returns:
    A tuple of:
      The feature value.
      The mutual information with the label. If use_adjusted_mutual_info is
        True, this is the mutual information minus the expected mutual
        information; otherwise it is the raw mutual information.
      The expected mutual information (EMI) if use_adjusted_mutual_info is
        True, otherwise NaN.
      The total weighted sum for the feature value.
  """
  # Compute the frequency of each label value.
  global_label_counts = (
      global_accumulator.mean * global_accumulator.weight *
      global_accumulator.count)
  feature_value, current_accumulator = feature_and_accumulator
  total_label_counts = sum(global_label_counts)
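  # Total weighted record count: number of records times their mean weight.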
  n = global_accumulator.count * global_accumulator.weight
  # TODO(b/168469757): Consider raising here once b/168469757 is resolved.
  if round(total_label_counts) != round(n):
    logging.warning(
        'Weighted label sum (%s) != total weighted count (%s), label means=%s',
        total_label_counts, n, global_accumulator.mean)
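  # No weighted records: mutual information is undefined for every feature.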
  if n == 0:
    return (feature_value, (float('NaN'), float('NaN'), 0))

  mutual_information = 0
  expected_mutual_information = 0 if use_adjusted_mutual_info else None
  x_i = (current_accumulator.count * current_accumulator.weight)
  # If x_i == n, the feature is a constant and thus has no information.
  if round(x_i) == round(n):
    return feature_value, (0, 0, x_i)
  if round(x_i) > round(n):
    raise ValueError(
        'Frequency of token {} ({}) is higher than the total number of '
        'records ({}). This likely means you have provided tft.vocabulary '
        'with input that has repeated tokens per row, rather than a set '
        'representation.'.format(feature_value, x_i, n))
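  # Accumulate the MI (and, if requested, EMI) contribution of each label
  # value.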
  for label_ix in range(len(global_label_counts)):
    y_i = global_label_counts[label_ix]
    if y_i == 0:
      continue
    local_mean = 0
    if label_ix < len(current_accumulator.mean):
      local_mean = current_accumulator.mean[label_ix]
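    # n_i: weighted count of this feature value co-occurring with label value
    # label_ix; _clip_probability guards against degenerate probabilities.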
    n_i = (
        _clip_probability(local_mean) * current_accumulator.weight *
        current_accumulator.count)
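    # Regularization: if the observed count n_i is within min_diff_from_avg
    # of its expectation under independence (x_i * y_i / n), treat the pair
    # as uninformative and skip it.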
    diff_from_avg = (x_i * y_i / n) - n_i
    if abs(diff_from_avg) < min_diff_from_avg:
      continue
    mutual_information += (
        info_theory.calculate_partial_mutual_information(n_i, x_i, y_i, n))
    if use_adjusted_mutual_info:
      expected_mutual_information += (
          info_theory.calculate_partial_expected_mutual_information(
              n, x_i, y_i))

  if use_adjusted_mutual_info:
    # TODO(b/127366670): Consider implementing the normalization step as per
    # AMI(x, y) = (MI(x, y) - EMI(x, y)) / (max(H(x), H(y)) - EMI(x, y))
    return (feature_value, (mutual_information - expected_mutual_information,
                            expected_mutual_information, x_i))
  else:
    return (feature_value, (mutual_information, float('NaN'), x_i))
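
A minimal, self-contained sketch of the docstring's MI formula for a single
feature value with a binary label. This is not part of analyzer_impls.py: the
function name and the example counts are hypothetical, and plain weighted
counts stand in for the accumulator fields (count * weight).

import math


def mi_for_feature_value(n_xy, x_total, y_total, n):
  """MI contribution of one feature value, scaled by its weighted count.

  Args:
    n_xy: weighted co-occurrence count of (x, y=positive).
    x_total: weighted count of the feature value x, i.e. sum(weights).
    y_total: weighted count of the positive label.
    n: total weighted count over all records.
  """
  p_y_given_x = n_xy / x_total  # P(y|x)
  p_y = y_total / n  # P(y)
  total = 0.0
  # Sum the y and ~y terms of the docstring's formula.
  for p_cond, p_marg in ((p_y_given_x, p_y), (1 - p_y_given_x, 1 - p_y)):
    if p_cond > 0:
      total += p_cond * math.log2(p_cond / p_marg)
  return x_total * total  # Scaled by sum(weights) rather than P(x).


# E.g. 8 of 10 weighted occurrences of the token are positive, while the
# global positive rate is 50 / 100:
print(mi_for_feature_value(8, 10, 50, 100))  # ~2.78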