in tensorflow_federated/python/core/impl/types/type_conversions.py [0:0]
def infer_type(arg: Any) -> Optional[computation_types.Type]:
  """Infers the TFF type of the argument (a `computation_types.Type` instance).

  WARNING: This function is only partially implemented.

  The kinds of arguments that are currently correctly recognized:
  - tensors, variables, and data sets,
  - things that are convertible to tensors (including numpy arrays, builtin
    types, as well as lists and tuples of any of the above, etc.),
  - nested lists, tuples, namedtuples, anonymous tuples, dicts, and
    OrderedDicts.

  Args:
    arg: The argument, the TFF type of which to infer.

  Returns:
    Either an instance of `computation_types.Type`, or `None` if the argument
    is `None`.

  Raises:
    TypeError: If the TFF type of `arg` cannot be inferred.
  """
  if arg is None:
    return None
  elif isinstance(arg, typed_object.TypedObject):
    # Objects that carry their own TFF type signature are authoritative.
    return arg.type_signature
  elif tf.is_tensor(arg):
    # `tf.is_tensor` returns true for some things that are not actually single
    # `tf.Tensor`s, including `tf.SparseTensor`s and `tf.RaggedTensor`s, so
    # those composite tensors are unpacked into structs of their components.
    if isinstance(arg, tf.RaggedTensor):
      return computation_types.StructWithPythonType(
          (('flat_values', infer_type(arg.flat_values)),
           ('nested_row_splits', infer_type(arg.nested_row_splits))),
          tf.RaggedTensor)
    elif isinstance(arg, tf.SparseTensor):
      return computation_types.StructWithPythonType(
          (('indices', infer_type(arg.indices)),
           ('values', infer_type(arg.values)),
           ('dense_shape', infer_type(arg.dense_shape))), tf.SparseTensor)
    else:
      # `base_dtype` strips any reference-dtype wrapper (e.g. for variables).
      return computation_types.TensorType(arg.dtype.base_dtype, arg.shape)
  elif isinstance(arg, TF_DATASET_REPRESENTATION_TYPES):
    element_type = computation_types.to_type(arg.element_spec)
    return computation_types.SequenceType(element_type)
  elif isinstance(arg, structure.Struct):
    # Struct elements may be named or unnamed; preserve names where present.
    return computation_types.StructType([
        (k, infer_type(v)) if k else infer_type(v)
        for k, v in structure.iter_elements(arg)
    ])
  elif py_typecheck.is_attrs(arg):
    # `recurse=False` so that nested values are typed by `infer_type` itself,
    # keeping their Python container information.
    items = attr.asdict(
        arg, dict_factory=collections.OrderedDict, recurse=False)
    return computation_types.StructWithPythonType(
        [(k, infer_type(v)) for k, v in items.items()], type(arg))
  elif py_typecheck.is_named_tuple(arg):
    # In Python 3.8 and later `_asdict` no longer returns an `OrderedDict`,
    # rather a regular `dict`, so normalize for consistent downstream handling.
    items = collections.OrderedDict(arg._asdict())
    return computation_types.StructWithPythonType(
        [(k, infer_type(v)) for k, v in items.items()], type(arg))
  elif isinstance(arg, dict):
    if isinstance(arg, collections.OrderedDict):
      items = arg.items()
    else:
      # Plain dicts are sorted by key to make the inferred type deterministic.
      items = sorted(arg.items())
    return computation_types.StructWithPythonType(
        [(k, infer_type(v)) for k, v in items], type(arg))
  elif isinstance(arg, (tuple, list)):
    elements = []
    all_elements_named = True
    for element in arg:
      all_elements_named &= py_typecheck.is_name_value_pair(element)
      elements.append(infer_type(element))
    # If this is a tuple of (name, value) pairs, the caller most likely
    # intended this to be a StructType, so we avoid storing the Python
    # container.
    if elements and all_elements_named:
      return computation_types.StructType(elements)
    else:
      return computation_types.StructWithPythonType(elements, type(arg))
  elif isinstance(arg, str):
    return computation_types.TensorType(tf.string)
  elif isinstance(arg, (np.generic, np.ndarray)):
    return computation_types.TensorType(
        tf.dtypes.as_dtype(arg.dtype), arg.shape)
  else:
    arg_type = type(arg)
    if arg_type is bool:
      return computation_types.TensorType(tf.bool)
    elif arg_type is int:
      # Choose the integral type based on the value's magnitude.
      if arg > tf.int64.max or arg < tf.int64.min:
        raise TypeError('No integral type support for values outside range '
                        f'[{tf.int64.min}, {tf.int64.max}]. Got: {arg}')
      elif arg > tf.int32.max or arg < tf.int32.min:
        return computation_types.TensorType(tf.int64)
      else:
        return computation_types.TensorType(tf.int32)
    elif arg_type is float:
      return computation_types.TensorType(tf.float32)
    else:
      # Now fall back onto the heavier-weight processing, as all else failed.
      # Use make_tensor_proto() to make sure to handle it consistently with
      # how TensorFlow is handling values (e.g., recognizing int as int32, as
      # opposed to int64 as in NumPy).
      try:
        # TODO(b/113112885): Find something more lightweight we could use here.
        tensor_proto = tf.make_tensor_proto(arg)
        return computation_types.TensorType(
            tf.dtypes.as_dtype(tensor_proto.dtype),
            tf.TensorShape(tensor_proto.tensor_shape))
      except TypeError as err:
        # Chain the original exception explicitly (`from err`) so the root
        # cause of the conversion failure survives in the traceback.
        raise TypeError('Could not infer the TFF type of {}: {}'.format(
            py_typecheck.type_string(type(arg)), err)) from err