in tensorflow_model_remediation/tools/tutorials_utils/uci/utils.py [0:0]
def get_uci_min_diff_datasets(split='train',
                              sample=None,
                              original_batch_size=128,
                              min_diff_batch_size=32):
  """Creates 3 UCI datasets for MinDiff training (original, male, female).

  Creates the 3 datasets that need to be packed together to use MinDiff on the
  UCI dataset when targeting a FNR gap between male and female slices. The
  datasets are:

  - original: The original dataset used for training. This will be the UCI
    dataset sampled according to the `sample` parameter.
  - MinDiff male: A dataset containing only positively labeled male examples.
    This dataset will be a subset of the original dataset (i.e. will change in
    size according to the `sample` parameter).
  - MinDiff female: A dataset containing only positively labeled female
    examples. This dataset will be a subset of the full data, regardless of
    the value of the `sample` parameter.

  Both MinDiff datasets are batched with `drop_remainder=True`, so a trailing
  partial batch is discarded. A slice with fewer than `min_diff_batch_size`
  positive examples therefore yields an empty dataset.

  Args:
    split: Default: 'train'. Split for the data. Can be either 'train' or
      'test'.
    sample: Default: `None`. Number between `0` and `1` representing the
      fraction of the data that will be used. If `None`, the entire dataset
      will be used.
    original_batch_size: Default: 128. Batch size for the original dataset.
    min_diff_batch_size: Default: 32. Batch size for the min_diff datasets
      (male and female).

  Returns:
    A tuple of datasets: (original, min_diff_male, min_diff_female).
  """
  sampled = get_uci_data(split=split, sample=sample)
  # The category literals carry a leading space; they must match the values
  # produced by get_uci_data.
  male_pos = sampled[(sampled['sex'] == ' Male') & (sampled['target'] == 1)]

  # Use the full dataset to get extra Female examples. When no sampling was
  # requested, `sampled` already holds the full split, so reuse it instead of
  # loading the same data a second time.
  # NOTE(review): assumes get_uci_data's default `sample` is None (i.e. the
  # call without `sample` returns the full split) -- confirm.
  full = sampled if sample is None else get_uci_data(split=split)
  female_pos = full[(full['sex'] == ' Female') & (full['target'] == 1)]

  # Convert to tf.data.Dataset.
  original_ds = df_to_dataset(
      sampled, shuffle=True, batch_size=original_batch_size)
  min_diff_male_ds = df_to_dataset(male_pos, shuffle=True).batch(
      min_diff_batch_size, drop_remainder=True)
  min_diff_female_ds = df_to_dataset(female_pos, shuffle=True).batch(
      min_diff_batch_size, drop_remainder=True)
  return original_ds, min_diff_male_ds, min_diff_female_ds