def to_type()

in tensorflow_federated/python/core/impl/types/computation_types.py [0:0]


def to_type(spec) -> Union[Type, None]:
  """Converts the argument into an instance of `tff.Type`.

  Examples of arguments convertible to tensor types:

  ```python
  tf.int32
  (tf.int32, [10])
  (tf.int32, [None])
  np.int32
  ```

  Examples of arguments convertible to flat named tuple types:

  ```python
  [tf.int32, tf.bool]
  (tf.int32, tf.bool)
  [('a', tf.int32), ('b', tf.bool)]
  ('a', tf.int32)
  collections.OrderedDict([('a', tf.int32), ('b', tf.bool)])
  ```

  Examples of arguments convertible to nested named tuple types:

  ```python
  (tf.int32, (tf.float32, tf.bool))
  (tf.int32, (('x', tf.float32), tf.bool))
  ((tf.int32, [1]), (('x', (tf.float32, [2])), (tf.bool, [3])))
  ```

  `attr.s` class instances can also be used to describe TFF types by populating
  the fields with the corresponding types:

  ```python
  @attr.s(auto_attribs=True)
  class MyDataClass:
    int_scalar: tf.Tensor
    string_array: tf.Tensor

    @classmethod
    def tff_type(cls) -> tff.Type:
      return tff.to_type(cls(
        int_scalar=tf.int32,
        string_array=tf.TensorSpec(dtype=tf.string, shape=[3]),
      ))

  @tff.tf_computation(MyDataClass.tff_type())
  def work(my_data):
    assert isinstance(my_data, MyDataClass)
    ...
  ```

  `tf.RaggedTensorSpec` and `tf.SparseTensorSpec` are converted into structs
  whose Python container is `tf.RaggedTensor` / `tf.SparseTensor`
  respectively.

  Args:
    spec: Either an instance of `tff.Type`, or an argument convertible to
      `tff.Type`, or `None`.

  Returns:
    `None` if `spec` is `None`; `spec` itself (unchanged) if it is already an
    instance of `tff.Type`; otherwise an instance of `tff.Type` corresponding
    to the given `spec` (a `TensorType`, `StructType`, or
    `StructWithPythonType`, or whatever `_to_type_from_attrs` builds for
    `attr.s` instances).

  Raises:
    TypeError: If `spec` is a mapping other than `collections.OrderedDict`
      (e.g. a plain `dict`), or if `spec` cannot be interpreted as a type
      spec at all. Errors raised while converting nested components may also
      propagate.
  """
  # TODO(b/113112108): Add multiple examples of valid type specs here in the
  # comments, in addition to the unit test.
  if spec is None or isinstance(spec, Type):
    # Already converted (or intentionally absent): pass through untouched.
    return spec
  elif _is_dtype_spec(spec):
    return TensorType(spec)
  elif isinstance(spec, tf.TensorSpec):
    return TensorType(spec.dtype, spec.shape)
  elif (isinstance(spec, tuple) and (len(spec) == 2) and
        _is_dtype_spec(spec[0]) and
        (isinstance(spec[1], tf.TensorShape) or
         (isinstance(spec[1], (list, tuple)) and all(
             (isinstance(x, int) or x is None) for x in spec[1])))):
    # We found a 2-element tuple of the form (dtype, shape), where dtype is an
    # instance of tf.DType, and shape is either an instance of tf.TensorShape,
    # or a list, or a tuple that can be fed as argument into a tf.TensorShape.
    # We thus convert this into a TensorType. This check must precede the
    # generic list/tuple branch below, which would otherwise treat the pair as
    # a two-element struct.
    return TensorType(spec[0], spec[1])
  elif isinstance(spec, (list, tuple)):
    if any(py_typecheck.is_name_value_pair(e) for e in spec):
      # The sequence has a (name, value) elements, the whole sequence is most
      # likely intended to be a `Struct`, do not store the Python
      # container.
      return StructType(spec)
    else:
      # Unnamed elements: remember the concrete container (list/tuple/
      # namedtuple subclass) so values can be rebuilt in the same shape.
      return StructWithPythonType(spec, type(spec))
  elif isinstance(spec, collections.OrderedDict):
    return StructWithPythonType(spec, type(spec))
  elif py_typecheck.is_attrs(spec):
    return _to_type_from_attrs(spec)
  elif isinstance(spec, collections.abc.Mapping):
    # This is an unsupported mapping, likely a `dict`. StructType adds an
    # ordering, which the original container did not have.
    raise TypeError(
        'Unsupported mapping type {}. Use collections.OrderedDict for '
        'mappings.'.format(py_typecheck.type_string(type(spec))))
  elif isinstance(spec, structure.Struct):
    return StructType(structure.to_elements(spec))
  elif isinstance(spec, tf.RaggedTensorSpec):
    if spec.flat_values_spec is not None:
      flat_values_type = to_type(spec.flat_values_spec)
    else:
      # We could provide a more specific shape here if `spec.shape is not None`:
      # `flat_values_shape = [None] + spec.shape[spec.ragged_rank + 1:]`
      # However, we can't go back from this type into a `tf.RaggedTensorSpec`,
      # meaning that round-tripping a `tf.RaggedTensorSpec` through
      # `type_conversions.type_to_tf_structure(computation_types.to_type(spec))`
      # would *not* be a no-op: it would clear away the extra shape information,
      # leading to compilation errors. This round-trip is tested in
      # `type_conversions_test.py` to ensure correctness.
      flat_values_shape = tf.TensorShape(None)
      flat_values_type = TensorType(spec.dtype, flat_values_shape)
    # One unnamed, variable-length row-splits vector per ragged dimension.
    nested_row_splits_type = StructWithPythonType(
        ([(None, TensorType(spec.row_splits_dtype, [None]))] *
         spec.ragged_rank), tuple)
    return StructWithPythonType([('flat_values', flat_values_type),
                                 ('nested_row_splits', nested_row_splits_type)],
                                tf.RaggedTensor)
  elif isinstance(spec, tf.SparseTensorSpec):
    dtype = spec.dtype
    shape = spec.shape
    # The number of stored (non-default) values is not known statically.
    unknown_num_values = None
    rank = None if shape is None else shape.rank
    return StructWithPythonType([
        ('indices', TensorType(tf.int64, [unknown_num_values, rank])),
        ('values', TensorType(dtype, [unknown_num_values])),
        ('dense_shape', TensorType(tf.int64, [rank])),
    ], tf.SparseTensor)
  else:
    raise TypeError(
        'Unable to interpret an argument of type {} as a type spec.'.format(
            py_typecheck.type_string(type(spec))))