in tensorflow_federated/python/core/impl/utils/tensorflow_utils.py [0:0]
def coerce_dataset_elements_to_tff_type_spec(dataset, element_type):
  """Map the elements of a dataset to a specified type.

  This is used to coerce a `tf.data.Dataset` that may have lost the ordering
  of dictionary keys back into a `collections.OrderedDict` (required by TFF).

  Both arguments are validated with `py_typecheck.check_type`. When
  `element_type` is a tensor type there is nothing to rebuild and `dataset`
  is returned as-is; otherwise every element is rebuilt, field by field, in
  the field order given by `element_type`, via `dataset.map`.

  Args:
    dataset: a `tf.data.Dataset` instance.
    element_type: a `tff.Type` specifying the type of the elements of
      `dataset`. Must be a `tff.TensorType` or `tff.StructType`.

  Returns:
    A `tf.data.Dataset` whose output types are compatible with
    `element_type`.

  Raises:
    ValueError: if the elements of `dataset` cannot be coerced into
      `element_type`.
  """
  py_typecheck.check_type(dataset,
                          type_conversions.TF_DATASET_REPRESENTATION_TYPES)
  py_typecheck.check_type(element_type, computation_types.Type)

  if element_type.is_tensor():
    # Tensor elements carry no container structure, so no mapping is needed.
    return dataset

  # This is similar to `reference_context.to_representation_for_type`;
  # look for opportunities to consolidate?
  def _to_representative_value(type_spec, elements):
    """Convert a container to a type understood by both TF and TFF.

    Recurses through `type_spec`, rebuilding `elements` so that its
    containers and field order follow `type_spec`.

    Args:
      type_spec: the type describing `elements` at this level of nesting.
      elements: the value (tensor or container of values) to rebuild.

    Returns:
      `elements` itself for tensor types, otherwise a rebuilt container.

    Raises:
      ValueError: if `type_spec` is neither a tensor type nor a struct type.
    """
    if type_spec.is_tensor():
      return elements
    elif type_spec.is_struct_with_python():
      if tf.is_tensor(elements):
        # In this case we have a singleton tuple tensor that may have been
        # unwrapped by tf.data.
        elements = [elements]
      py_type = computation_types.StructWithPythonType.get_container_type(
          type_spec)
      field_types = structure.iter_elements(type_spec)
      if (issubclass(py_type, collections.abc.Mapping) or
          py_typecheck.is_attrs(py_type)):
        # Name-keyed containers: look every field up by name so the result
        # follows the field order of `type_spec`, then rebuild the container
        # from keyword arguments.
        values = collections.OrderedDict(
            (name, _to_representative_value(field_type, elements[name]))
            for name, field_type in field_types)
        return py_type(**values)
      else:
        # Positional containers: pair fields with elements in order.
        values = [
            _to_representative_value(field_type, e)
            for (_, field_type), e in zip(field_types, elements)
        ]
        if py_typecheck.is_named_tuple(py_type):
          # Named tuples take their fields as separate positional arguments.
          return py_type(*values)
        return py_type(values)
    elif type_spec.is_struct():
      field_types = structure.to_elements(type_spec)
      # A generator lets `all` stop at the first unnamed field instead of
      # building an intermediate list.
      is_all_named = all(name is not None for name, _ in field_types)
      if is_all_named:
        if py_typecheck.is_named_tuple(elements):
          # Keep the incoming named tuple class; fields are matched to the
          # elements by position and passed back by name.
          values = collections.OrderedDict(
              (name, _to_representative_value(field_type, e))
              for (name, field_type), e in zip(field_types, elements))
          return type(elements)(**values)
        else:
          # Otherwise index `elements` by field name and return an
          # `OrderedDict` in the field order of `type_spec`.
          values = [(name, _to_representative_value(field_type, elements[name]))
                    for name, field_type in field_types]
          return collections.OrderedDict(values)
      else:
        # At least one field is unnamed: fall back to a plain positional
        # tuple.
        return tuple(
            _to_representative_value(t, e) for t, e in zip(type_spec, elements))
    else:
      raise ValueError(
          'Coercing a dataset with elements of expected type {!s}, '
          'produced a value with incompatible type `{!s}`. Value: '
          '{!s}'.format(type_spec, type(elements), elements))

  # tf.data.Dataset of tuples will unwrap the tuple in the `map()` call, so we
  # must pass a function taking *args. However, if the call was originally only
  # a single tuple, it is now "double wrapped" and must be unwrapped before
  # traversing.
  def _unwrap_args(*args):
    """Undo the tuple unpacking done by `map()` and coerce the element."""
    if len(args) == 1:
      return _to_representative_value(element_type, args[0])
    else:
      return _to_representative_value(element_type, args)

  return dataset.map(_unwrap_args)