in tensorflow_model_remediation/min_diff/keras/utils/input_utils.py [0:0]
def build_min_diff_dataset(sensitive_group_dataset,
                           nonsensitive_group_dataset) -> tf.data.Dataset:
  # pyformat: disable
  """Build MinDiff dataset from sensitive and nonsensitive datasets.

  Arguments:
    sensitive_group_dataset: `tf.data.Dataset` or valid MinDiff structure
      (unnested dict) of `tf.data.Dataset`s containing only examples that
      belong to the sensitive group.
    nonsensitive_group_dataset: `tf.data.Dataset` or valid MinDiff structure
      (unnested dict) of `tf.data.Dataset`s containing only examples that do
      **not** belong to the sensitive group.

  This function builds a `tf.data.Dataset` containing examples that are meant
  to only be used when calculating a `min_diff_loss`. This resulting dataset
  will need to be packed with the original dataset used for the original task
  of the model which can be done by calling `utils.pack_min_diff_data`.

  Warning: All input datasets should be batched **before** being passed in.

  Each input dataset must output a tuple in the format used in
  `tf.keras.Model.fit`. Specifically the output must be a tuple of
  length 1, 2 or 3 in the form `(x, y, sample_weight)`.

  This output will be parsed internally in the following way:

  ```
  batch = ...  # Batch from any of the input datasets.
  x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(batch)
  ```

  Note: the `y` component of input datasets will be ignored completely so it
  can be set to `None` or any other arbitrary value. If `sample_weight` is not
  included, it can be left out entirely.

  Every batch from the returned `tf.data.Dataset` will contain one batch from
  each of the input datasets. Each returned batch will be a tuple or structure
  (matching the structure of the inputs) of `(min_diff_x, min_diff_membership,
  min_diff_sample_weight)` where, for each pair of input datasets:

  - `min_diff_x`: is formed by concatenating the `x` components of the paired
    datasets. The structure of these must match. If they don't the dataset
    will raise an error at the first batch.
  - `min_diff_membership`: is a tensor of size `[min_diff_batch_size, 1]`
    indicating which dataset each example comes from (`1.0` for
    `sensitive_group_dataset` and `0.0` for `nonsensitive_group_dataset`).
  - `min_diff_sample_weight`: is formed by concatenating the `sample_weight`
    components of the paired datasets. If both are `None`, then this will be
    set to `None`. If only one is `None`, it is replaced with a `Tensor` of
    ones of the appropriate shape.

  Returns:
    A `tf.data.Dataset` whose output is a tuple or structure (matching the
    structure of the inputs) of `(min_diff_x, min_diff_membership,
    min_diff_sample_weight)`.

  Raises:
    ValueError: If either `sensitive_group_dataset` or
      `nonsensitive_group_dataset` is not a valid MinDiff structure (unnested
      dict).
    ValueError: If `sensitive_group_dataset` and `nonsensitive_group_dataset`
      do not have the same structure.
  """
  # pyformat: enable
  # Validate that both inputs are either a single dataset or an unnested dict
  # of datasets before attempting to pair them up.
  structure_utils.validate_min_diff_structure(
      sensitive_group_dataset,
      struct_name="sensitive_group_dataset",
      element_type=tf.data.Dataset)
  structure_utils.validate_min_diff_structure(
      nonsensitive_group_dataset,
      struct_name="nonsensitive_group_dataset",
      element_type=tf.data.Dataset)

  try:
    structure_utils._assert_same_min_diff_structure(sensitive_group_dataset,
                                                    nonsensitive_group_dataset)
  except Exception as e:
    # Re-raise with the same exception type but a clearer message; chain the
    # original exception so the full traceback is preserved for debugging.
    raise type(e)("`sensitive_group_dataset` and `nonsensitive_group_dataset` "
                  "do not have the same structure:\n{}".format(e)) from e

  # Repeat both sides indefinitely so that zipping never exhausts one side
  # before the other (the paired datasets may have different cardinalities).
  sensitive_group_dataset = tf.nest.map_structure(
      lambda dataset: dataset.repeat(), sensitive_group_dataset)
  nonsensitive_group_dataset = tf.nest.map_structure(
      lambda dataset: dataset.repeat(), nonsensitive_group_dataset)

  dataset = tf.data.Dataset.zip(
      (sensitive_group_dataset, nonsensitive_group_dataset))

  def _build_single_batch(single_sensitive_batch, single_nonsensitive_batch):
    """Combine one sensitive and one nonsensitive batch into MinDiff data."""
    # Unpack both batches. The `y` component is intentionally discarded.
    sensitive_x, _, sensitive_sample_weight = (
        tf.keras.utils.unpack_x_y_sample_weight(single_sensitive_batch))
    nonsensitive_x, _, nonsensitive_sample_weight = (
        tf.keras.utils.unpack_x_y_sample_weight(single_nonsensitive_batch))

    # sensitive_x and nonsensitive_x must have the same structure.
    try:
      tf.nest.assert_same_structure(sensitive_x, nonsensitive_x)
    except Exception as e:
      # Chain the original exception to preserve the underlying details.
      raise type(e)("The x component structure of (one of) the "
                    "`sensitive_group_dataset`(s) does not match that of the "
                    "(corresponding) `nonsensitive_group_dataset` x structure "
                    "(sensitive shown first): {}".format(e)) from e

    # Create min_diff_data.
    # Merge sensitive_x and nonsensitive_x to form min_diff_x by concatenating
    # each pair of corresponding leaf tensors.
    flat_sensitive_x = tf.nest.flatten(sensitive_x)
    flat_nonsensitive_x = tf.nest.flatten(nonsensitive_x)
    flat_min_diff_x = [
        _tensor_concat(t1, t2)
        for t1, t2 in zip(flat_sensitive_x, flat_nonsensitive_x)
    ]
    min_diff_x = tf.nest.pack_sequence_as(sensitive_x, flat_min_diff_x)

    # min_diff_membership indicates which dataset each example comes from:
    # 1.0 for sensitive examples, 0.0 for nonsensitive ones.
    sensitive_shape = [tf.shape(flat_sensitive_x[0])[0], 1]
    nonsensitive_shape = [tf.shape(flat_nonsensitive_x[0])[0], 1]
    min_diff_membership = tf.concat(
        axis=0,
        values=[
            tf.ones(sensitive_shape, dtype=tf.float32),
            tf.zeros(nonsensitive_shape, dtype=tf.float32)
        ])

    # min_diff_sample_weight is the concatenation of both sample_weights.
    min_diff_sample_weight = None  # Default if both sample_weights are None.
    if (sensitive_sample_weight is not None or
        nonsensitive_sample_weight is not None):
      # If only one side provides a sample_weight, substitute ones for the
      # missing side so the concatenation is well defined.
      if sensitive_sample_weight is None:
        sensitive_sample_weight = tf.ones(sensitive_shape, dtype=tf.float32)
      elif nonsensitive_sample_weight is None:
        nonsensitive_sample_weight = tf.ones(
            nonsensitive_shape, dtype=tf.float32)
      min_diff_sample_weight = tf.concat(
          [sensitive_sample_weight, nonsensitive_sample_weight], axis=0)

    # Pack the three components and return them.
    return tf.keras.utils.pack_x_y_sample_weight(min_diff_x,
                                                 min_diff_membership,
                                                 min_diff_sample_weight)

  def _map_fn(sensitive_batch, nonsensitive_batch):
    """Apply _build_single_batch to each paired dataset in the structure."""
    flat_sensitive_batch = structure_utils._flatten_min_diff_structure(
        sensitive_batch)
    flat_nonsensitive_batch = structure_utils._flatten_min_diff_structure(
        nonsensitive_batch)
    flat_min_diff_data = [
        _build_single_batch(single_sensitive_batch, single_nonsensitive_batch)
        for single_sensitive_batch, single_nonsensitive_batch in zip(
            flat_sensitive_batch, flat_nonsensitive_batch)
    ]
    return structure_utils._pack_min_diff_sequence_as(sensitive_batch,
                                                      flat_min_diff_data)

  # Reshape dataset output.
  return dataset.map(_map_fn)