in tensorflow_data_validation/statistics/generators/mutual_information.py [0:0]
def _calculate_mi(self,
                  examples_dict: Dict[types.FeaturePath, List[List[Any]]],
                  labels: List[List[Any]],
                  seed: int,
                  k: int = 3) -> Dict[types.FeaturePath, Dict[str, float]]:
    """Estimates the AMI between every feature and the label.

    Args:
      examples_dict: A dictionary mapping each feature to its list of values.
      labels: A List where the ith index represents the encoded label for the
        ith example. Each encoded label is of type:
        List[Optional[Union[LabelType, int]]], depending on if it is univalent
        or multivalent.
      seed: An int value to seed the RNG used in MI computation.
      k: The number of nearest neighbors. Must be >= 3.

    Returns:
      A dict keyed by feature. Each value is a one-entry dict mapping
      self._custom_stats_key to the AMI computed for that feature. The dict
      is empty when examples_dict is empty, or when the partition holds fewer
      than k examples and self._allow_invalid_partitions is set.
    """
    result = {}
    if not examples_dict:
        return result

    # Put each label column into its own 1D array.
    label_list = list(np.array(labels).T)
    # Multivalent features are encoded into multivalent numeric features, so
    # only a univalent categorical label is flagged as categorical. The flag
    # is identical for every label column.
    label_is_categorical = (
        self._label_feature in self._categorical_features and
        self._label_feature not in self._multivalent_features)
    label_categorical_mask = [label_is_categorical] * len(label_list)

    # The row count of the first feature is taken as the partition size.
    num_rows = len(next(iter(examples_dict.values())))
    if num_rows < k and self._allow_invalid_partitions:
        logging.warning(
            "Partition had %s examples for k = %s. Skipping AMI computation.",
            num_rows, k)
        return result

    for feature_column, feature_values in examples_dict.items():
        feature_array = np.array(feature_values)
        is_categorical = feature_column in self._categorical_features

        # A feature cannot be predictive when it is always empty, when it is
        # categorical and fully unique, or when every value is null. The
        # checks short-circuit in that order.
        if (feature_array.size == 0 or
                (is_categorical and self._is_unique_array(feature_array)) or
                pd.isnull(feature_array).all()):
            result[feature_column] = {self._custom_stats_key: 0.0}
            continue

        # One 1D array per feature column, each sharing the same mask flag.
        feature_list = list(feature_array.T)
        feature_is_categorical = (
            is_categorical and
            feature_column not in self._multivalent_features)
        feature_categorical_mask = [feature_is_categorical] * len(feature_list)

        ami = mutual_information_util.adjusted_mutual_information(
            label_list,
            feature_list,
            label_categorical_mask,
            feature_categorical_mask,
            k=k,
            seed=seed)
        result[feature_column] = {self._custom_stats_key: ami}
    return result