in tensorflow_federated/python/core/impl/types/computation_types.py [0:0]
def _is_dtype_shape_pair(spec) -> bool:
  """Returns whether `spec` is a `(dtype, shape)` 2-tuple describing a tensor.

  The first element must be accepted by `_is_dtype_spec`. The second must be
  either a `tf.TensorShape`, or a list/tuple whose entries are all integers
  (built-in `int` or NumPy integer scalars) or `None`, i.e. something that can
  be fed as an argument into `tf.TensorShape`. An empty list/tuple is accepted
  as well.

  Args:
    spec: An arbitrary Python object.
  """
  import numbers  # pylint: disable=g-import-not-at-top

  if not isinstance(spec, tuple) or len(spec) != 2:
    return False
  dtype, shape = spec
  if not _is_dtype_spec(dtype):
    return False
  if isinstance(shape, tf.TensorShape):
    return True
  # `numbers.Integral` also covers NumPy integer scalars such as `np.int64(3)`,
  # which a plain `isinstance(dim, int)` check would reject.
  return isinstance(shape, (list, tuple)) and all(
      dim is None or isinstance(dim, numbers.Integral) for dim in shape)


def _to_type_from_ragged_spec(spec) -> StructWithPythonType:
  """Converts a `tf.RaggedTensorSpec` into a TFF structure type.

  The result is a `StructWithPythonType` whose container is `tf.RaggedTensor`,
  with two elements: `flat_values` and `nested_row_splits`. The latter is a
  `tuple`-backed struct of `spec.ragged_rank` unnamed tensors of dtype
  `spec.row_splits_dtype` and shape `[None]`.

  Args:
    spec: A `tf.RaggedTensorSpec`.
  """
  if spec.flat_values_spec is not None:
    flat_values_type = to_type(spec.flat_values_spec)
  else:
    # We could provide a more specific shape here if `spec.shape is not None`:
    # `flat_values_shape = [None] + spec.shape[spec.ragged_rank + 1:]`
    # However, we can't go back from this type into a `tf.RaggedTensorSpec`,
    # meaning that round-tripping a `tf.RaggedTensorSpec` through
    # `type_conversions.type_to_tf_structure(computation_types.to_type(spec))`
    # would *not* be a no-op: it would clear away the extra shape information,
    # leading to compilation errors. This round-trip is tested in
    # `type_conversions_test.py` to ensure correctness.
    flat_values_type = TensorType(spec.dtype, tf.TensorShape(None))
  row_splits_type = TensorType(spec.row_splits_dtype, [None])
  nested_row_splits_type = StructWithPythonType(
      [(None, row_splits_type)] * spec.ragged_rank, tuple)
  return StructWithPythonType([('flat_values', flat_values_type),
                               ('nested_row_splits', nested_row_splits_type)],
                              tf.RaggedTensor)


def _to_type_from_sparse_spec(spec) -> StructWithPythonType:
  """Converts a `tf.SparseTensorSpec` into a TFF structure type.

  The result is a `StructWithPythonType` whose container is `tf.SparseTensor`,
  with the elements `indices` (`tf.int64`, shape `[num_values, rank]`),
  `values` (`spec.dtype`, shape `[num_values]`) and `dense_shape` (`tf.int64`,
  shape `[rank]`). The number of stored values is always left unknown; the
  rank is taken from `spec.shape` (and is `None` when unknown).

  Args:
    spec: A `tf.SparseTensorSpec`.
  """
  shape = spec.shape
  unknown_num_values = None
  rank = None if shape is None else shape.rank
  return StructWithPythonType([
      ('indices', TensorType(tf.int64, [unknown_num_values, rank])),
      ('values', TensorType(spec.dtype, [unknown_num_values])),
      ('dense_shape', TensorType(tf.int64, [rank])),
  ], tf.SparseTensor)


def to_type(spec) -> Union[Type, None]:
  """Converts the argument into an instance of `tff.Type`.

  Examples of arguments convertible to tensor types:

  ```python
  tf.int32
  (tf.int32, [10])
  (tf.int32, [None])
  np.int32
  ```

  Examples of arguments convertible to flat named tuple types:

  ```python
  [tf.int32, tf.bool]
  (tf.int32, tf.bool)
  [('a', tf.int32), ('b', tf.bool)]
  ('a', tf.int32)
  collections.OrderedDict([('a', tf.int32), ('b', tf.bool)])
  ```

  Examples of arguments convertible to nested named tuple types:

  ```python
  (tf.int32, (tf.float32, tf.bool))
  (tf.int32, (('x', tf.float32), tf.bool))
  ((tf.int32, [1]), (('x', (tf.float32, [2])), (tf.bool, [3])))
  ```

  `attr.s` class instances can also be used to describe TFF types by populating
  the fields with the corresponding types:

  ```python
  @attr.s(auto_attribs=True)
  class MyDataClass:
    int_scalar: tf.Tensor
    string_array: tf.Tensor

    @classmethod
    def tff_type(cls) -> tff.Type:
      return tff.to_type(cls(
          int_scalar=tf.int32,
          string_array=tf.TensorSpec(dtype=tf.string, shape=[3]),
      ))

  @tff.tf_computation(MyDataClass.tff_type())
  def work(my_data):
    assert isinstance(my_data, MyDataClass)
    ...
  ```

  `tf.TensorSpec`, `tf.RaggedTensorSpec`, `tf.SparseTensorSpec` and
  `structure.Struct` instances are accepted as well.

  Args:
    spec: Either an instance of `tff.Type`, or an argument convertible to
      `tff.Type`.

  Returns:
    An instance of `tff.Type` corresponding to the given `spec`. `None` and
    existing `tff.Type` instances are returned unchanged.

  Raises:
    TypeError: If `spec` is a mapping other than `collections.OrderedDict`
      (e.g. a plain `dict`), or cannot otherwise be interpreted as a type spec.
  """
  # TODO(b/113112108): Add multiple examples of valid type specs here in the
  # comments, in addition to the unit test.
  if spec is None or isinstance(spec, Type):
    return spec
  elif _is_dtype_spec(spec):
    return TensorType(spec)
  elif isinstance(spec, tf.TensorSpec):
    return TensorType(spec.dtype, spec.shape)
  elif _is_dtype_shape_pair(spec):
    # A `(dtype, shape)` pair describes a single tensor, not a 2-element
    # structure; this check must run before the generic sequence branch below.
    return TensorType(spec[0], spec[1])
  elif isinstance(spec, (list, tuple)):
    if any(py_typecheck.is_name_value_pair(e) for e in spec):
      # The sequence has (name, value) elements, so the whole sequence is most
      # likely intended to be a `Struct`; do not store the Python container.
      return StructType(spec)
    else:
      return StructWithPythonType(spec, type(spec))
  elif isinstance(spec, collections.OrderedDict):
    return StructWithPythonType(spec, type(spec))
  elif py_typecheck.is_attrs(spec):
    return _to_type_from_attrs(spec)
  elif isinstance(spec, collections.abc.Mapping):
    # This is an unsupported mapping, likely a `dict`. StructType adds an
    # ordering, which the original container did not have.
    raise TypeError(
        'Unsupported mapping type {}. Use collections.OrderedDict for '
        'mappings.'.format(py_typecheck.type_string(type(spec))))
  elif isinstance(spec, structure.Struct):
    return StructType(structure.to_elements(spec))
  elif isinstance(spec, tf.RaggedTensorSpec):
    return _to_type_from_ragged_spec(spec)
  elif isinstance(spec, tf.SparseTensorSpec):
    return _to_type_from_sparse_spec(spec)
  else:
    raise TypeError(
        'Unable to interpret an argument of type {} as a type spec.'.format(
            py_typecheck.type_string(type(spec))))