def infer_type()

in tensorflow_federated/python/core/impl/types/type_conversions.py [0:0]


def infer_type(arg: Any) -> Optional[computation_types.Type]:
  """Infers the TFF type of the argument (a `computation_types.Type` instance).

  WARNING: This function is only partially implemented.

  The kinds of arguments that are currently correctly recognized:
  - objects that already carry a TFF type (`typed_object.TypedObject`),
  - tensors (including `tf.SparseTensor` and `tf.RaggedTensor`), variables,
    and data sets,
  - things that are convertible to tensors (including numpy arrays, builtin
    types, as well as lists and tuples of any of the above, etc.),
  - nested lists, tuples, namedtuples, attrs classes, `structure.Struct`s,
    dicts, and OrderedDicts.

  Args:
    arg: The argument, the TFF type of which to infer.

  Returns:
    Either an instance of `computation_types.Type`, or `None` if the argument is
    `None`.

  Raises:
    TypeError: If `arg` (or a nested element) is a Python `int` outside the
      `tf.int64` range, or if no TFF type can be inferred for it. In the latter
      case the error reported by `tf.make_tensor_proto` is chained as the cause.
  """
  if arg is None:
    return None
  elif isinstance(arg, typed_object.TypedObject):
    # The object already knows its TFF type; no inference needed.
    return arg.type_signature
  elif tf.is_tensor(arg):
    # `tf.is_tensor` returns true for some things that are not actually single
    # `tf.Tensor`s, including `tf.SparseTensor`s and `tf.RaggedTensor`s. Those
    # are represented as structs of their component tensors, tagged with the
    # original Python container type.
    if isinstance(arg, tf.RaggedTensor):
      return computation_types.StructWithPythonType(
          (('flat_values', infer_type(arg.flat_values)),
           ('nested_row_splits', infer_type(arg.nested_row_splits))),
          tf.RaggedTensor)
    elif isinstance(arg, tf.SparseTensor):
      return computation_types.StructWithPythonType(
          (('indices', infer_type(arg.indices)),
           ('values', infer_type(arg.values)),
           ('dense_shape', infer_type(arg.dense_shape))), tf.SparseTensor)
    else:
      return computation_types.TensorType(arg.dtype.base_dtype, arg.shape)
  elif isinstance(arg, TF_DATASET_REPRESENTATION_TYPES):
    element_type = computation_types.to_type(arg.element_spec)
    return computation_types.SequenceType(element_type)
  elif isinstance(arg, structure.Struct):
    # Preserve names where present; unnamed elements stay unnamed.
    return computation_types.StructType([
        (k, infer_type(v)) if k else infer_type(v)
        for k, v in structure.iter_elements(arg)
    ])
  elif py_typecheck.is_attrs(arg):
    # `recurse=False` keeps nested values intact so that they are inferred by
    # the recursive `infer_type` calls below rather than flattened to dicts.
    items = attr.asdict(
        arg, dict_factory=collections.OrderedDict, recurse=False)
    return computation_types.StructWithPythonType(
        [(k, infer_type(v)) for k, v in items.items()], type(arg))
  elif py_typecheck.is_named_tuple(arg):
    # In Python 3.8 and later `_asdict` no longer returns an `OrderedDict`, but
    # a regular `dict`; wrap it to keep the field order explicit.
    items = collections.OrderedDict(arg._asdict())
    return computation_types.StructWithPythonType(
        [(k, infer_type(v)) for k, v in items.items()], type(arg))
  elif isinstance(arg, dict):
    # `OrderedDict`s keep their own order; plain dicts are sorted so that the
    # inferred type does not depend on insertion order.
    if isinstance(arg, collections.OrderedDict):
      items = arg.items()
    else:
      items = sorted(arg.items())
    return computation_types.StructWithPythonType(
        [(k, infer_type(v)) for k, v in items], type(arg))
  elif isinstance(arg, (tuple, list)):
    elements = []
    all_elements_named = True
    for element in arg:
      all_elements_named &= py_typecheck.is_name_value_pair(element)
      elements.append(infer_type(element))
    # If this is a tuple of (name, value) pairs, the caller most likely intended
    # this to be a StructType, so we avoid storing the Python container.
    # NOTE(review): `elements` holds the inferred type of each whole pair, not
    # `(name, type)` tuples, so the names are not passed to `StructType` here
    # explicitly. Confirm that this is the intended result for lists of
    # name/value pairs.
    if elements and all_elements_named:
      return computation_types.StructType(elements)
    else:
      return computation_types.StructWithPythonType(elements, type(arg))
  elif isinstance(arg, str):
    return computation_types.TensorType(tf.string)
  elif isinstance(arg, (np.generic, np.ndarray)):
    return computation_types.TensorType(
        tf.dtypes.as_dtype(arg.dtype), arg.shape)
  else:
    arg_type = type(arg)
    # Exact type checks (not `isinstance`) so that `bool`, a subclass of
    # `int`, is not treated as an integer.
    if arg_type is bool:
      return computation_types.TensorType(tf.bool)
    elif arg_type is int:
      # Choose the integral type based on value: int32 when it fits, otherwise
      # int64; anything beyond int64 cannot be represented.
      if arg > tf.int64.max or arg < tf.int64.min:
        raise TypeError('No integral type support for values outside range '
                        f'[{tf.int64.min}, {tf.int64.max}]. Got: {arg}')
      elif arg > tf.int32.max or arg < tf.int32.min:
        return computation_types.TensorType(tf.int64)
      else:
        return computation_types.TensorType(tf.int32)
    elif arg_type is float:
      return computation_types.TensorType(tf.float32)
    else:
      # Now fall back onto the heavier-weight processing, as all else failed.
      # Use make_tensor_proto() to make sure to handle it consistently with
      # how TensorFlow is handling values (e.g., recognizing int as int32, as
      # opposed to int64 as in NumPy).
      try:
        # TODO(b/113112885): Find something more lightweight we could use here.
        tensor_proto = tf.make_tensor_proto(arg)
        return computation_types.TensorType(
            tf.dtypes.as_dtype(tensor_proto.dtype),
            tf.TensorShape(tensor_proto.tensor_shape))
      except TypeError as err:
        # Chain the original error so the TensorFlow failure is reported as
        # the direct cause.
        raise TypeError('Could not infer the TFF type of {}: {}'.format(
            py_typecheck.type_string(type(arg)), err)) from err