def pack_min_diff_data()

in tensorflow_model_remediation/min_diff/keras/utils/input_utils.py


def pack_min_diff_data(original_dataset: tf.data.Dataset,
                       sensitive_group_dataset=None,
                       nonsensitive_group_dataset=None,
                       min_diff_dataset=None) -> tf.data.Dataset:
  # pyformat: disable
  """Packs `min_diff_data` with the `x` component of the original dataset.

  Arguments:
    original_dataset: `tf.data.Dataset` that was used before applying min
      diff. The output should conform to the format used in
      `tf.keras.Model.fit`.
    sensitive_group_dataset: `tf.data.Dataset` or valid MinDiff structure
      (unnested dict) of `tf.data.Dataset`s containing only examples that
      belong to the sensitive group.

      This must be passed in if `nonsensitive_group_dataset` is passed in.
      Furthermore, the `x` component for every batch should have the same
      structure as that of the `original_dataset` batches' `x` components.
    nonsensitive_group_dataset: `tf.data.Dataset` or valid MinDiff structure
      (unnested dict) of `tf.data.Dataset`s containing only examples that do
      **not** belong to the sensitive group.

      This must be passed in if `sensitive_group_dataset` is passed in.
      Furthermore, the `x` component for every batch should have the same
      structure as that of the `original_dataset` batches' `x` components.
    min_diff_dataset: `tf.data.Dataset` or valid MinDiff structure (unnested
      dict) of `tf.data.Dataset`s containing only examples to be used to
      calculate the `min_diff_loss`.

      This should only be set if neither `sensitive_group_dataset` nor
      `nonsensitive_group_dataset` is passed in.
      Furthermore, the `x` component for every batch should have the same
      structure as that of the `original_dataset` batches' `x` components.

  This function should be used to create the dataset that will be passed to
  `min_diff.keras.MinDiffModel` during training and, optionally, during
  evaluation.

  The inputs must take exactly one of two forms: either both
  `sensitive_group_dataset` and `nonsensitive_group_dataset` are passed in and
  `min_diff_dataset` is left unset, or only `min_diff_dataset` is passed in.
  In the former case, `min_diff_data` will be built using
  `utils.build_min_diff_dataset`.
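
  For example, both valid calling patterns are sketched below (`train_ds`,
  `sensitive_ds`, `nonsensitive_ds` and `min_diff_ds` are hypothetical batched
  `tf.data.Dataset`s conforming to the format described in this docstring):

  ```
  # Option 1: build `min_diff_data` from the two group datasets.
  packed_ds = pack_min_diff_data(
      original_dataset=train_ds,
      sensitive_group_dataset=sensitive_ds,
      nonsensitive_group_dataset=nonsensitive_ds)

  # Option 2: pass a prebuilt `min_diff_dataset` directly.
  packed_ds = pack_min_diff_data(
      original_dataset=train_ds,
      min_diff_dataset=min_diff_ds)
  ```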

  Warning: All input datasets should be batched **before** being passed in.

  Each input dataset must output a tuple in the format used in
  `tf.keras.Model.fit`. Specifically, the output must be a tuple of
  length 1, 2, or 3 in the form `(x, y, sample_weight)`.

  This output will be parsed internally in the following way:

  ```
  batch = ...  # Batch from any one of the input datasets.
  x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(batch)
  ```

  Every batch from the returned `tf.data.Dataset` will contain one batch from
  each of the input datasets. Each returned batch will be a tuple of
  `(packed_inputs, original_y, original_sample_weight)` matching the length of
  `original_dataset` batches, where:

  - `packed_inputs`: is an instance of `utils.MinDiffPackedInputs` containing:

    - `original_inputs`: `x` component taken directly from the
        `original_dataset` batch.
    - `min_diff_data`: batch of data formed from `sensitive_group_dataset` and
      `nonsensitive_group_dataset` (as described in
      `utils.build_min_diff_dataset`) or taken directly from `min_diff_dataset`.

  - `original_y`: is the `y` component taken directly from the
    `original_dataset` batch.
  - `original_sample_weight`: is the `sample_weight` component taken directly
    from the `original_dataset` batch.
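
  For illustration, a single batch from the returned dataset could be inspected
  as in the following sketch (`packed_ds` is a hypothetical dataset returned by
  this function):

  ```
  batch = next(iter(packed_ds))  # One batch from the returned dataset.
  packed_inputs, original_y, original_sample_weight = (
      tf.keras.utils.unpack_x_y_sample_weight(batch))
  original_x = packed_inputs.original_inputs   # `x` from `original_dataset`.
  min_diff_data = packed_inputs.min_diff_data  # Used for the `min_diff_loss`.
  ```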

  `min_diff_data` will be used in `min_diff.keras.MinDiffModel` when calculating
  the `min_diff_loss`. It is a tuple or structure (matching the structure of the
  inputs) of `(min_diff_x, min_diff_membership, min_diff_sample_weight)`.
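
  Continuing the sketch above, with a single (unnested) MinDiff dataset a
  `min_diff_data` batch can be unpacked in the same way:

  ```
  min_diff_x, min_diff_membership, min_diff_sample_weight = (
      tf.keras.utils.unpack_x_y_sample_weight(min_diff_data))
  ```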

  Caution: If you are passing in `min_diff_dataset` make sure that each
  `min_diff_data` batch contains about the same number of sensitive and
  nonsensitive examples as indicated by `min_diff_membership` (when passing in
  `sensitive_group_dataset` and `nonsensitive_group_dataset` this is determined
  by their batch sizes).

  Returns:
    A `tf.data.Dataset` whose output is a tuple of (`packed_inputs`,
      `original_y`, `original_sample_weight`) matching the output length
      of `original_dataset`.
  """
  # pyformat: enable
  # Either sensitive_group_dataset and nonsensitive_group_dataset are both set
  # and min_diff_dataset is not or vice versa.
  min_diff_dataset_present = min_diff_dataset is not None
  sensitive_dataset_present = sensitive_group_dataset is not None
  nonsensitive_dataset_present = nonsensitive_group_dataset is not None
  # Case where min_diff_dataset is set and the others are not.
  set_to_use_min_diff_dataset = (
      min_diff_dataset_present and
      not (sensitive_dataset_present or nonsensitive_dataset_present))
  # Case where sensitive_group_dataset and nonsensitive_group_dataset are both
  # set and min_diff_dataset is not.
  set_to_construct_min_diff_dataset = ((sensitive_dataset_present and
                                        nonsensitive_dataset_present) and
                                       not min_diff_dataset_present)
  if not (set_to_use_min_diff_dataset or set_to_construct_min_diff_dataset):
    raise ValueError(
        "Invalid arguments: You must either pass in only the `min_diff_dataset`"
        " (and leave `sensitive_group_dataset` and `nonsensitive_group_dataset`"
        " as None) or set both `sensitive_group_dataset` and "
        "`nonsensitive_group_dataset` (and leave `min_diff_dataset` as None), "
        "given: \n"
        "\n`sensitive_group_dataset`: {}"
        "\n`nonsensitive_group_dataset`: {}"
        "\n`min_diff_dataset`: {}".format(sensitive_group_dataset,
                                          nonsensitive_group_dataset,
                                          min_diff_dataset))

  # First construct the min_diff_dataset if need be.
  if set_to_construct_min_diff_dataset:
    min_diff_dataset = build_min_diff_dataset(sensitive_group_dataset,
                                              nonsensitive_group_dataset)
  else:
    # Validate min_diff_dataset since it was passed in.
    structure_utils.validate_min_diff_structure(
        min_diff_dataset,
        struct_name="min_diff_dataset",
        element_type=tf.data.Dataset)

  dataset = tf.data.Dataset.zip((original_dataset, min_diff_dataset))

  def _map_fn(original_batch, min_diff_batch):
    # Unpack original batch.
    original_x, original_y, original_sample_weight = (
        tf.keras.utils.unpack_x_y_sample_weight(original_batch))

    # Assert that all min_diff_xs have the same structure as original_x.
    # TODO: Should we assert that Tensor shapes are the same (other than the
    # number of examples)?

    min_diff_xs = [
        tf.keras.utils.unpack_x_y_sample_weight(batch)[0]  # First element is x.
        for batch in structure_utils._flatten_min_diff_structure(min_diff_batch)
    ]
    for min_diff_x in min_diff_xs:
      try:
        tf.nest.assert_same_structure(original_x, min_diff_x)
      except Exception as e:
        raise type(e)(
            "The x component structure of (one of) the `min_diff_dataset`(s) "
            "does not match that of the original x structure (original shown "
            "first): {}".format(e))

    # Pack min_diff_batch with original_x.
    return _pack_as_original(
        original_batch,
        MinDiffPackedInputs(
            original_inputs=original_x, min_diff_data=min_diff_batch),
        original_y, original_sample_weight)

  # Reshape dataset output.
  return dataset.map(_map_fn)