def coerce_dataset_elements_to_tff_type_spec()

in tensorflow_federated/python/core/impl/utils/tensorflow_utils.py


def coerce_dataset_elements_to_tff_type_spec(dataset, element_type):
  """Map the elements of a dataset to a specified type.

  This is used to coerce a `tf.data.Dataset` that may have lost the ordering
  of dictionary keys back into a `collections.OrderedDict` (required by TFF).

  Args:
    dataset: a `tf.data.Dataset` instance.
    element_type: a `tff.Type` specifying the type of the elements of
      `dataset`. Must be a `tff.TensorType` or `tff.StructType`.

  Returns:
    A `tf.data.Dataset` whose output types are compatible with
    `element_type`.

  Raises:
    ValueError: if the elements of `dataset` cannot be coerced into
      `element_type`.
  """
  py_typecheck.check_type(dataset,
                          type_conversions.TF_DATASET_REPRESENTATION_TYPES)
  py_typecheck.check_type(element_type, computation_types.Type)
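  # A bare tensor element needs no re-structuring; `tf.data` already yields
  # tensors, so the dataset can be returned as-is.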
  if element_type.is_tensor():
    return dataset
  # This is similar to `reference_context.to_representation_for_type`;
  # look for opportunities to consolidate the two.
  def _to_representative_value(type_spec, elements):
    """Convert to a container to a type understood by TF and TFF."""
    if type_spec.is_tensor():
      return elements
    elif type_spec.is_struct_with_python():
      if tf.is_tensor(elements):
        # In this case we have a singleton tuple tensor that may have been
        # unwrapped by tf.data.
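        # (`Dataset.map` unpacks tuple elements into positional args, so a
        # one-tuple element arrives here as a bare tensor; re-wrap it so the
        # traversal below lines up with `type_spec`.)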
        elements = [elements]
      py_type = computation_types.StructWithPythonType.get_container_type(
          type_spec)
      field_types = structure.iter_elements(type_spec)
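      # Mappings and attrs classes are rebuilt by field name; all other
      # containers (lists, tuples, namedtuples) are rebuilt positionally.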
      if (issubclass(py_type, collections.abc.Mapping) or
          py_typecheck.is_attrs(py_type)):
        values = collections.OrderedDict(
            (name, _to_representative_value(field_type, elements[name]))
            for name, field_type in field_types)
        return py_type(**values)
      else:
        values = [
            _to_representative_value(field_type, e)
            for (_, field_type), e in zip(field_types, elements)
        ]
        if py_typecheck.is_named_tuple(py_type):
          return py_type(*values)
        return py_type(values)
    elif type_spec.is_struct():
      field_types = structure.to_elements(type_spec)
      is_all_named = all(name is not None for name, _ in field_types)
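      # Fully named structs are rebuilt keyed by name (preserving a
      # namedtuple container if the value arrived as one); structs with any
      # unnamed fields are rebuilt positionally as plain tuples.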
      if is_all_named:
        if py_typecheck.is_named_tuple(elements):
          values = collections.OrderedDict(
              (name, _to_representative_value(field_type, e))
              for (name, field_type), e in zip(field_types, elements))
          return type(elements)(**values)
        else:
          values = [(name, _to_representative_value(field_type, elements[name]))
                    for name, field_type in field_types]
          return collections.OrderedDict(values)
      else:
        return tuple(
            _to_representative_value(t, e) for t, e in zip(type_spec, elements))
    else:
      raise ValueError(
          'Coercing a dataset with elements of expected type {!s} '
          'produced a value with incompatible type `{!s}`. Value: '
          '{!s}'.format(type_spec, type(elements), elements))

  # A `tf.data.Dataset` of tuples unpacks each tuple into positional args in
  # the `map()` call, so we must pass a function taking *args. However, if
  # the element was not a tuple, it arrives wrapped in a single-element args
  # tuple and must be unwrapped before traversing.
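  # For example, a dataset of `(x, y)` pairs invokes `fn(x, y)`, while a
  # dataset of dicts invokes `fn(d)` with a single argument.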
  def _unwrap_args(*args):
    if len(args) == 1:
      return _to_representative_value(element_type, args[0])
    else:
      return _to_representative_value(element_type, args)

  return dataset.map(_unwrap_args)
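
Example usage (a minimal sketch, not from the module itself): this assumes
TF 2.x eager execution, the public `tff.StructWithPythonType` and
`tff.TensorType` aliases, and that this function is importable from the impl
module above; `ds` and `element_type` are illustrative names.

import collections

import tensorflow as tf
import tensorflow_federated as tff

# A dataset of plain dicts; `tf.data` does not guarantee ordered-dict
# elements, which TFF requires.
ds = tf.data.Dataset.from_tensor_slices({'x': [1.0, 2.0], 'y': [3, 4]})

# The TFF type the elements should conform to, with an OrderedDict container.
element_type = tff.StructWithPythonType(
    [('x', tff.TensorType(tf.float32)),
     ('y', tff.TensorType(tf.int32))],
    collections.OrderedDict)

coerced = coerce_dataset_elements_to_tff_type_spec(ds, element_type)
for element in coerced:
  print(type(element))  # collections.OrderedDict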