def get_uci_min_diff_datasets()

in tensorflow_model_remediation/tools/tutorials_utils/uci/utils.py [0:0]


def get_uci_min_diff_datasets(split='train',
                              sample=None,
                              original_batch_size=128,
                              min_diff_batch_size=32):
  """Builds the three UCI datasets required for MinDiff training.

  MinDiff on the UCI data, when aimed at the FNR gap between the male and
  female slices, needs three datasets packed together. This function produces
  them:
    - original: The regular training data, i.e. the UCI data restricted by the
      `sample` parameter.
    - MinDiff male: Only the positively labeled male examples, taken from the
      sampled data. Its size therefore follows the `sample` parameter.
    - MinDiff female: Only the positively labeled female examples, taken from
      the full (unsampled) data. Its size does not depend on `sample`.

  Args:
    split: Default: 'train'. Which split of the data to load; either 'train'
      or 'test'.
    sample: Default: `None`. Fraction of the data to use, given as a number
      between `0` and `1`. `None` means the whole dataset is used.
    original_batch_size: Default: 128. Batch size applied to the original
      dataset.
    min_diff_batch_size: Default: 32. Batch size applied to both min_diff
      datasets (male and female).

  Returns:
    A tuple of datasets: (original, min_diff_male, min_diff_female).
  """

  def _positive_rows(frame, sex_value):
    # Rows of `frame` that belong to the requested sex slice and carry a
    # positive label.
    in_slice = frame['sex'] == sex_value
    is_positive = frame['target'] == 1
    return frame[in_slice & is_positive]

  def _min_diff_dataset(frame):
    # The MinDiff slices are batched here (not inside `df_to_dataset`) so that
    # the trailing partial batch can be dropped.
    unbatched = df_to_dataset(frame, shuffle=True)
    return unbatched.batch(min_diff_batch_size, drop_remainder=True)

  # NOTE: the leading space in ' Male' / ' Female' is part of the value being
  # matched; keep it.
  original_df = get_uci_data(split=split, sample=sample)
  male_positive_df = _positive_rows(original_df, ' Male')

  # Female positives are drawn from the unsampled data to get extra examples.
  unsampled_df = get_uci_data(split=split)
  female_positive_df = _positive_rows(unsampled_df, ' Female')

  # Turn each frame into a tf.data.Dataset.
  original_ds = df_to_dataset(
      original_df, shuffle=True, batch_size=original_batch_size)
  male_ds = _min_diff_dataset(male_positive_df)
  female_ds = _min_diff_dataset(female_positive_df)

  return original_ds, male_ds, female_ds