# build_min_diff_dataset()
#
# in tensorflow_model_remediation/min_diff/keras/utils/input_utils.py [0:0]


def build_min_diff_dataset(sensitive_group_dataset,
                           nonsensitive_group_dataset) -> tf.data.Dataset:
  # pyformat: disable
  """Build MinDiff dataset from sensitive and nonsensitive datasets.

  Arguments:
    sensitive_group_dataset: `tf.data.Dataset` or valid MinDiff structure
      (unnested dict) of `tf.data.Dataset`s containing only examples that
      belong to the sensitive group.
    nonsensitive_group_dataset: `tf.data.Dataset` or valid MinDiff structure
      (unnested dict) of `tf.data.Dataset`s containing only examples that do
      **not** belong to the sensitive group.

  This function builds a `tf.data.Dataset` containing examples that are meant to
  only be used when calculating a `min_diff_loss`. This resulting dataset will
  need to be packed with the original dataset used for the original task of the
  model which can be done by calling `utils.pack_min_diff_data`.

  Warning: All input datasets should be batched **before** being passed in.

  Each input dataset must output a tuple in the format used in
  `tf.keras.Model.fit`. Specifically the output must be a tuple of
  length 1, 2 or 3 in the form `(x, y, sample_weight)`.

  This output will be parsed internally in the following way:

  ```
  batch = ...  # Batch from any of the input datasets.
  x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(batch)
  ```

  Note: the `y` component of input datasets will be ignored completely so it can
  be set to `None` or any other arbitrary value. If `sample_weight` is not
  included, it can be left out entirely.

  Every batch from the returned `tf.data.Dataset` will contain one batch from
  each of the input datasets. Each returned batch will be a tuple or structure
  (matching the structure of the inputs) of `(min_diff_x, min_diff_membership,
  min_diff_sample_weight)` where, for each pair of input datasets:

  - `min_diff_x`: is formed by concatenating the `x` components of the paired
    datasets. The structure of these must match. If they don't the dataset will
    raise an error at the first batch.
  - `min_diff_membership`: is a tensor of size `[min_diff_batch_size, 1]`
    indicating which dataset each example comes from (`1.0` for
    `sensitive_group_dataset` and `0.0` for `nonsensitive_group_dataset`).
  - `min_diff_sample_weight`: is formed by concatenating the `sample_weight`
    components of the paired datasets. If both are `None`, then this will be set
    to `None`. If only one is `None`, it is replaced with a `Tensor` of ones of
    the appropriate shape.

  Returns:
    A `tf.data.Dataset` whose output is a tuple or structure (matching the
      structure of the inputs) of `(min_diff_x, min_diff_membership,
      min_diff_sample_weight)`.

  Raises:
    ValueError: If either `sensitive_group_dataset` or
      `nonsensitive_group_dataset` is not a valid MinDiff structure (unnested
      dict).
    ValueError: If `sensitive_group_dataset` and `nonsensitive_group_dataset` do
      not have the same structure.
  """
  # pyformat: enable
  # Validate that each input is a valid MinDiff structure (a single
  # `tf.data.Dataset` or an unnested dict of them).
  structure_utils.validate_min_diff_structure(
      sensitive_group_dataset,
      struct_name="sensitive_group_dataset",
      element_type=tf.data.Dataset)
  structure_utils.validate_min_diff_structure(
      nonsensitive_group_dataset,
      struct_name="nonsensitive_group_dataset",
      element_type=tf.data.Dataset)
  try:
    structure_utils._assert_same_min_diff_structure(sensitive_group_dataset,
                                                    nonsensitive_group_dataset)
  except Exception as e:
    # Re-raise with a clearer message, preserving the original exception type
    # and chaining the cause (`from e`) so the underlying traceback survives.
    raise type(e)("`sensitive_group_dataset` and `nonsensitive_group_dataset` "
                  "do not have the same structure:\n{}".format(e)) from e

  # Repeat both inputs indefinitely so the zipped dataset below is not
  # truncated to whichever input has fewer batches.
  sensitive_group_dataset = tf.nest.map_structure(
      lambda dataset: dataset.repeat(), sensitive_group_dataset)
  nonsensitive_group_dataset = tf.nest.map_structure(
      lambda dataset: dataset.repeat(), nonsensitive_group_dataset)

  dataset = tf.data.Dataset.zip(
      (sensitive_group_dataset, nonsensitive_group_dataset))

  def _build_single_batch(single_sensitive_batch, single_nonsensitive_batch):
    # Builds one `(min_diff_x, min_diff_membership, min_diff_sample_weight)`
    # tuple from a single pair of (sensitive, nonsensitive) batches.

    # Unpack both batches. The `y` component is intentionally discarded (see
    # docstring above).
    sensitive_x, _, sensitive_sample_weight = (
        tf.keras.utils.unpack_x_y_sample_weight(single_sensitive_batch))
    nonsensitive_x, _, nonsensitive_sample_weight = (
        tf.keras.utils.unpack_x_y_sample_weight(single_nonsensitive_batch))

    # sensitive_x and nonsensitive_x must have the same structure.
    try:
      tf.nest.assert_same_structure(sensitive_x, nonsensitive_x)
    except Exception as e:
      # Preserve the exception type and chain the cause for debugging.
      raise type(e)("The x component structure of (one of) the "
                    "`sensitive_group_dataset`(s) does not match that of the "
                    "(corresponding) `nonsensitive_group_dataset` x structure "
                    "(sensitive shown first): {}".format(e)) from e

    # Create min_diff_data.
    # Merge sensitive_x and nonsensitive_x elementwise (sensitive examples
    # first) to form min_diff_x.
    flat_sensitive_x = tf.nest.flatten(sensitive_x)
    flat_nonsensitive_x = tf.nest.flatten(nonsensitive_x)
    flat_min_diff_x = [
        _tensor_concat(t1, t2)
        for t1, t2 in zip(flat_sensitive_x, flat_nonsensitive_x)
    ]
    min_diff_x = tf.nest.pack_sequence_as(sensitive_x, flat_min_diff_x)

    # min_diff_membership indicates which dataset each example comes from:
    # 1.0 for sensitive examples, 0.0 for nonsensitive ones. Batch sizes are
    # read off the leading dimension of the first flattened x tensor.
    sensitive_shape = [tf.shape(flat_sensitive_x[0])[0], 1]
    nonsensitive_shape = [tf.shape(flat_nonsensitive_x[0])[0], 1]
    min_diff_membership = tf.concat(
        axis=0,
        values=[
            tf.ones(sensitive_shape, dtype=tf.float32),
            tf.zeros(nonsensitive_shape, dtype=tf.float32)
        ])
    # min_diff_sample_weight is the concatenation of both sample_weights.
    min_diff_sample_weight = None  # Default if both sample_weights are None.
    if (sensitive_sample_weight is not None or
        nonsensitive_sample_weight is not None):
      # At least one weight is present; replace a missing one with ones of the
      # matching batch shape so the concat is well-defined.
      if sensitive_sample_weight is None:
        sensitive_sample_weight = tf.ones(sensitive_shape, dtype=tf.float32)
      elif nonsensitive_sample_weight is None:
        nonsensitive_sample_weight = tf.ones(
            nonsensitive_shape, dtype=tf.float32)
      min_diff_sample_weight = tf.concat(
          [sensitive_sample_weight, nonsensitive_sample_weight], axis=0)

    # Pack the three components and return them
    return tf.keras.utils.pack_x_y_sample_weight(min_diff_x,
                                                 min_diff_membership,
                                                 min_diff_sample_weight)

  def _map_fn(sensitive_batch, nonsensitive_batch):
    # Applies _build_single_batch to each corresponding pair of datasets in
    # the (possibly dict-valued) MinDiff structures, then repacks the results
    # into the original structure.
    flat_sensitive_batch = structure_utils._flatten_min_diff_structure(
        sensitive_batch)
    flat_nonsensitive_batch = structure_utils._flatten_min_diff_structure(
        nonsensitive_batch)

    flat_min_diff_data = [
        _build_single_batch(single_sensitive_batch, single_nonsensitive_batch)
        for single_sensitive_batch, single_nonsensitive_batch in zip(
            flat_sensitive_batch, flat_nonsensitive_batch)
    ]

    return structure_utils._pack_min_diff_sequence_as(sensitive_batch,
                                                      flat_min_diff_data)

  # Reshape dataset output.
  return dataset.map(_map_fn)