in tensorflow_transform/beam/analyzer_impls.py [0:0]
def _calculate_mutual_information_for_feature_value(feature_and_accumulator,
                                                    global_accumulator,
                                                    use_adjusted_mutual_info,
                                                    min_diff_from_avg):
"""Calculates the (possibly adjusted) mutual information of a feature value.
Used as a measure of relatedness between a single feature value and a label.
Mutual information is calculated as:
H(x, y) = (sum(weights) *
[(P(y|x)*log2(P(y|x)/P(y))) + (P(~y|x)*log2(P(~y|x)/P(~y)))])
where x is feature and y is label. We use sum(weights) instead of P(x), as
this makes the mutual information more interpretable.
If we don't divide by sum(weights), it can be thought of as an adjusted
weighted count.
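
  As a hypothetical worked example: with 10 weighted records in total, a
  token with sum(weights) = 4, and a label that is positive in half of all
  records (P(y) = 0.5) but in 3 of the token's 4 records (P(y|x) = 0.75):
  H(x, y) = 4 * [(0.75*log2(0.75/0.5)) + (0.25*log2(0.25/0.5))] ~= 0.75
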
  If use_adjusted_mutual_info is True, we use Adjusted Mutual Information
  (AMI), which accounts for relatedness due to chance. AMI is generally
  calculated as:
  AMI(x, y) = (MI(x, y) - EMI(x, y)) / (max(H(x), H(y)) - EMI(x, y))
  where x is the feature and y is the label. Here, we leave off the
  normalization and only subtract the expected mutual information (EMI) from
  the mutual information.

  The calculation is based on the following paper:
  Vinh, N. X.; Epps, J.; Bailey, J. (2009). "Information theoretic measures
  for clusterings comparison". Proceedings of the 26th Annual International
  Conference on Machine Learning - ICML '09. p. 1.
  doi:10.1145/1553374.1553511. ISBN 9781605585161.

  A short summary can be found in the Wikipedia entry:
  https://en.wikipedia.org/wiki/Adjusted_mutual_information

  Args:
    feature_and_accumulator: A tuple of the form:
      (feature, WeightedMeanAndVarCombiner.accumulator_class) where `feature`
      is the single token in the vocabulary for which the (possibly adjusted)
      mutual information with the label is being computed, `mean` is the
      weighted mean positive for each label value given x, `count` is the
      count of weights for the feature, and `weight` is the mean of the
      weights for the feature.
    global_accumulator: A WeightedMeanAndVarCombiner.accumulator_class where
      `mean` is the weighted mean positive for each label value over all
      features, `count` is the count over all features, and `weight` is the
      mean of the weights over all features.
    use_adjusted_mutual_info: If True, use adjusted mutual information.
    min_diff_from_avg: A regularization parameter that pushes low MI/AMI
      values towards zero. The mutual information of a (feature, label) pair
      will be adjusted to zero whenever the absolute difference between the
      weight and the expected (average) weight is lower than
      min_diff_from_avg.

  Returns:
    A tuple of:
      The feature value.
      The mutual information with the label. If use_adjusted_mutual_info is
        True, this is the mutual information minus the expected mutual
        information; otherwise it is the raw mutual information.
      The expected mutual information (EMI) if use_adjusted_mutual_info is
        True, otherwise NaN.
      The total weighted sum for the feature value.
  """
  # Compute the frequency of each label value.
  global_label_counts = (
      global_accumulator.mean * global_accumulator.weight *
      global_accumulator.count)
  feature_value, current_accumulator = feature_and_accumulator
  total_label_counts = sum(global_label_counts)
  n = global_accumulator.count * global_accumulator.weight
  # TODO(b/168469757): Consider raising here once b/168469757 is resolved.
  if round(total_label_counts) != round(n):
    logging.warning(
        'Weighted label sum (%s) != total weighted count (%s), label means=%s',
        total_label_counts, n, global_accumulator.mean)
  if n == 0:
    return (feature_value, (float('NaN'), float('NaN'), 0))
  mutual_information = 0
  expected_mutual_information = 0 if use_adjusted_mutual_info else None
  # Total weighted count of records containing this feature value.
  x_i = (current_accumulator.count * current_accumulator.weight)
  # If x_i == n, the feature is a constant and thus has no information.
  if round(x_i) == round(n):
    return feature_value, (0, 0, x_i)
  if round(x_i) > round(n):
    raise ValueError(
        'Frequency of token {} is higher than the total weighted count of'
        ' records ({} > {}).'.format(feature_value, x_i, n) +
        ' This likely means you have provided tft.vocabulary with input that'
        ' has repeated tokens per row, rather than a set representation.')
  for label_ix in range(len(global_label_counts)):
    y_i = global_label_counts[label_ix]
    if y_i == 0:
      continue
    local_mean = 0
    if label_ix < len(current_accumulator.mean):
      local_mean = current_accumulator.mean[label_ix]
    # Weighted co-occurrence count of this feature value and this label value.
    n_i = (
        _clip_probability(local_mean) * current_accumulator.weight *
        current_accumulator.count)
    # x_i * y_i / n is the co-occurrence count expected if the feature and
    # label were independent; regularize small deviations from it to zero.
    diff_from_avg = (x_i * y_i / n) - n_i
    if abs(diff_from_avg) < min_diff_from_avg:
      continue
    mutual_information += (
        info_theory.calculate_partial_mutual_information(n_i, x_i, y_i, n))
    if use_adjusted_mutual_info:
      expected_mutual_information += (
          info_theory.calculate_partial_expected_mutual_information(
              n, x_i, y_i))
  if use_adjusted_mutual_info:
    # TODO(b/127366670): Consider implementing the normalization step as per
    # AMI(x, y) = (MI(x, y) - EMI(x, y)) / (max(H(x), H(y)) - EMI(x, y))
    return (feature_value, (mutual_information - expected_mutual_information,
                            expected_mutual_information, x_i))
  else:
    return (feature_value, (mutual_information, float('NaN'), x_i))
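

# The following is an editor-added, minimal sketch of the per-label terms the
# loop above relies on, assuming the textbook formulas from the cited
# Vinh et al. (2009) paper: the weighted pointwise mutual-information term
# n_i * log2((n_i * n) / (x_i * y_i)) and the expected mutual information
# under a hypergeometric null. `_partial_mi_sketch` and `_partial_emi_sketch`
# are hypothetical names, not the actual tensorflow_transform `info_theory`
# implementations, which also handle weighted (fractional) counts and numeric
# edge cases.
import math


def _partial_mi_sketch(n_i, x_i, y_i, n):
  """Weighted pointwise MI of one (feature value, label value) cell."""
  if n_i == 0 or x_i == 0 or y_i == 0 or n == 0:
    return 0.0
  return n_i * math.log2((n_i * n) / (x_i * y_i))


def _partial_emi_sketch(n, x_i, y_i):
  """EMI of one cell under the hypergeometric null (integer counts only)."""
  emi = 0.0
  for n_ij in range(max(0, x_i + y_i - n), min(x_i, y_i) + 1):
    if n_ij == 0:
      continue  # The n_ij * log2(...) term vanishes at zero.
    # Probability of exactly n_ij co-occurrences if the feature value and
    # label value were independent.
    prob = (math.comb(y_i, n_ij) * math.comb(n - y_i, x_i - n_ij) /
            math.comb(n, x_i))
    emi += prob * _partial_mi_sketch(n_ij, x_i, y_i, n)
  return emi


# Example: n = 10 records, token in x_i = 4 of them, label positive in
# y_i = 5, co-occurring n_ij = 3 times: 3 * log2(30 / 20) ~= 1.755.
assert abs(_partial_mi_sketch(3, 4, 5, 10) - 3 * math.log2(1.5)) < 1e-9
assert _partial_emi_sketch(10, 4, 5) > 0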