in tensorflow_model_remediation/min_diff/keras/utils/input_utils.py [0:0]
def pack_min_diff_data(original_dataset: tf.data.Dataset,
                       sensitive_group_dataset=None,
                       nonsensitive_group_dataset=None,
                       min_diff_dataset=None) -> tf.data.Dataset:
  # pyformat: disable
  """Packs `min_diff_data` with the `x` component of the original dataset.

  Arguments:
    original_dataset: `tf.data.Dataset` that was used before applying min
      diff. The output should conform to the format used in
      `tf.keras.Model.fit`.
    sensitive_group_dataset: `tf.data.Dataset` or valid MinDiff structure
      (unnested dict) of `tf.data.Dataset`s containing only examples that
      belong to the sensitive group.

      This must be passed in if `nonsensitive_group_dataset` is passed in.
      Furthermore, the `x` component for every batch should have the same
      structure as that of the `original_dataset` batches' `x` components.
    nonsensitive_group_dataset: `tf.data.Dataset` or valid MinDiff structure
      (unnested dict) of `tf.data.Dataset`s containing only examples that do
      **not** belong to the sensitive group.

      This must be passed in if `sensitive_group_dataset` is passed in.
      Furthermore, the `x` component for every batch should have the same
      structure as that of the `original_dataset` batches' `x` components.
    min_diff_dataset: `tf.data.Dataset` or valid MinDiff structure (unnested
      dict) of `tf.data.Dataset`s containing only examples to be used to
      calculate the `min_diff_loss`.

      This should only be set if neither `sensitive_group_dataset` nor
      `nonsensitive_group_dataset` is passed in.
      Furthermore, the `x` component for every batch should have the same
      structure as that of the `original_dataset` batches' `x` components.

  This function should be used to create the dataset that will be passed to
  `min_diff.keras.MinDiffModel` during training and, optionally, during
  evaluation.

  The inputs should either have both `sensitive_group_dataset` and
  `nonsensitive_group_dataset` passed in and `min_diff_dataset` left unset or
  vice versa. In the case of the former, `min_diff_data` will be built using
  `utils.build_min_diff_dataset`.

  Warning: All input datasets should be batched **before** being passed in.

  Each input dataset must output a tuple in the format used in
  `tf.keras.Model.fit`. Specifically the output must be a tuple of
  length 1, 2 or 3 in the form `(x, y, sample_weight)`.

  This output will be parsed internally in the following way:

  ```
  batch = ...  # Batch from any one of the input datasets.
  x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(batch)
  ```

  Every batch from the returned `tf.data.Dataset` will contain one batch from
  each of the input datasets. Each returned batch will be a tuple of
  `(packed_inputs, original_y, original_sample_weight)` matching the length of
  `original_dataset` batches where:

  - `packed_inputs`: is an instance of `utils.MinDiffPackedInputs` containing:

    - `original_inputs`: `x` component taken directly from the
        `original_dataset` batch.
    - `min_diff_data`: batch of data formed from `sensitive_group_dataset` and
      `nonsensitive_group_dataset` (as described in
      `utils.build_min_diff_dataset`) or taken directly from `min_diff_dataset`.

  - `original_y`: is the `y` component taken directly from the
    `original_dataset` batch.
  - `original_sample_weight`: is the `sample_weight` component taken directly
    from the `original_dataset` batch.

  `min_diff_data` will be used in `min_diff.keras.MinDiffModel` when calculating
  the `min_diff_loss`. It is a tuple or structure (matching the structure of the
  inputs) of `(min_diff_x, min_diff_membership, min_diff_sample_weight)`.

  Caution: If you are passing in `min_diff_dataset` make sure that each
  `min_diff_data` batch contains about the same number of sensitive and
  nonsensitive examples as indicated by `min_diff_membership` (when passing in
  `sensitive_group_dataset` and `nonsensitive_group_dataset` this is determined
  by their batch sizes).

  Returns:
    A `tf.data.Dataset` whose output is a tuple of (`packed_inputs`,
      `original_y`, `original_sample_weight`) matching the output length
      of `original_dataset`.

  Raises:
    ValueError: If the combination of arguments is invalid, i.e. it is neither
      "only `min_diff_dataset` set" nor "both `sensitive_group_dataset` and
      `nonsensitive_group_dataset` set with `min_diff_dataset` unset".
    ValueError or TypeError: If the `x` component structure of (one of) the
      min_diff batch(es) does not match the structure of the original `x`
      component. The error raised by `tf.nest.assert_same_structure` is
      re-raised with the same type and an explanatory prefix.

  A `min_diff_dataset` that is passed in directly is additionally checked by
  `structure_utils.validate_min_diff_structure`; presumably that helper raises
  on an invalid structure (its behavior is defined outside this function).
  """
  # pyformat: enable
  # Either sensitive_group_dataset and nonsensitive_group_dataset are both set
  # and min_diff_dataset is not or vice versa.
  min_diff_dataset_present = min_diff_dataset is not None
  sensitive_dataset_present = sensitive_group_dataset is not None
  nonsensitive_dataset_present = nonsensitive_group_dataset is not None

  # Case where min_diff_dataset is set and the others are not.
  set_to_use_min_diff_dataset = (
      min_diff_dataset_present and
      not (sensitive_dataset_present or nonsensitive_dataset_present))

  # Case where sensitive_group_dataset and nonsensitive_group_dataset are both
  # set and min_diff_dataset is not.
  set_to_construct_min_diff_dataset = ((sensitive_dataset_present and
                                        nonsensitive_dataset_present) and
                                       not min_diff_dataset_present)

  if not (set_to_use_min_diff_dataset or set_to_construct_min_diff_dataset):
    raise ValueError(
        "Invalid arguments: You must either pass in only the `min_diff_dataset`"
        " (and leave `sensitive_group_dataset` and `nonsensitive_group_dataset`"
        " as None) or set both `sensitive_group_dataset` and "
        "`nonsensitive_group_dataset` (and leave `min_diff_dataset` as None), "
        "given: \n"
        "\n`sensitive_group_dataset`: {}"
        "\n`nonsensitive_group_dataset`: {}"
        "\n`min_diff_dataset`: {}".format(sensitive_group_dataset,
                                          nonsensitive_group_dataset,
                                          min_diff_dataset))

  # First construct the min_diff_dataset if need be.
  if set_to_construct_min_diff_dataset:
    min_diff_dataset = build_min_diff_dataset(sensitive_group_dataset,
                                              nonsensitive_group_dataset)
  else:
    # Validate min_diff_dataset since it was passed in.
    structure_utils.validate_min_diff_structure(
        min_diff_dataset,
        struct_name="min_diff_dataset",
        element_type=tf.data.Dataset)

  # Pair every original batch with the corresponding min_diff batch.
  dataset = tf.data.Dataset.zip((original_dataset, min_diff_dataset))

  def _map_fn(original_batch, min_diff_batch):
    """Packs one `(original_batch, min_diff_batch)` pair into a single batch."""
    # Unpack original batch.
    original_x, original_y, original_sample_weight = (
        tf.keras.utils.unpack_x_y_sample_weight(original_batch))

    # Assert that all min_diff_xs have the same structure as original_x.
    # TODO: Should we assert that Tensor shapes are the same (other
    #       than number of examples).
    min_diff_xs = [
        tf.keras.utils.unpack_x_y_sample_weight(batch)[0]  # First element is x.
        for batch in structure_utils._flatten_min_diff_structure(min_diff_batch)
    ]
    for min_diff_x in min_diff_xs:
      try:
        tf.nest.assert_same_structure(original_x, min_diff_x)
      except (TypeError, ValueError) as e:
        # Only the error types documented for `assert_same_structure` are
        # wrapped: both accept a single message argument, so rebuilding them
        # via `type(e)(...)` is safe. Anything else propagates untouched
        # instead of risking a failure inside an unknown constructor.
        # `from e` keeps the original error attached as the explicit cause.
        raise type(e)(
            "The x component structure of (one of) the `min_diff_dataset`(s) "
            "does not match that of the original x structure (original shown "
            "first): {}".format(e)) from e

    # Pack min_diff_batch with original_x.
    return _pack_as_original(
        original_batch,
        MinDiffPackedInputs(
            original_inputs=original_x, min_diff_data=min_diff_batch),
        original_y, original_sample_weight)

  # Reshape dataset output.
  return dataset.map(_map_fn)