in tensorflow_federated/python/core/impl/types/type_transformations.py [0:0]
def transform_type_postorder(
    type_signature: computation_types.Type,
    transform_fn: Callable[[computation_types.Type],
                           Tuple[computation_types.Type, bool]]
) -> Tuple[computation_types.Type, bool]:
  """Walks type tree of `type_signature` postorder, calling `transform_fn`.

  Children of a node are transformed before the node itself. A container type
  (federated, sequence, function, struct) is rebuilt from its transformed
  children only when at least one child reported a mutation; otherwise the
  original instance is handed to `transform_fn` unchanged.

  Args:
    type_signature: Instance of `computation_types.Type` to transform
      recursively.
    transform_fn: Callable applied to each node in the type tree of
      `type_signature`. It receives a `computation_types.Type` and must return
      a `(type, mutated)` pair, where `mutated` indicates whether the returned
      type differs from its argument.

  Returns:
    A `(type, mutated)` pair. `type` is the possibly transformed version of
    `type_signature`, with each node in its tree the result of applying
    `transform_fn` to the corresponding node in `type_signature`. `mutated` is
    truthy if `transform_fn` reported a mutation for this node or for any node
    beneath it.

  Raises:
    TypeError: If the types don't match the specification above, or if
      `type_signature` is of a kind this function does not know how to
      traverse.
  """
  py_typecheck.check_type(type_signature, computation_types.Type)
  py_typecheck.check_callable(transform_fn)
  if type_signature.is_federated():
    transformed_member, member_mutated = transform_type_postorder(
        type_signature.member, transform_fn)
    if member_mutated:
      # Rebuild only when needed, preserving placement and all_equal.
      type_signature = computation_types.FederatedType(
          transformed_member, type_signature.placement,
          type_signature.all_equal)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or member_mutated
  elif type_signature.is_sequence():
    transformed_element, element_mutated = transform_type_postorder(
        type_signature.element, transform_fn)
    if element_mutated:
      type_signature = computation_types.SequenceType(transformed_element)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or element_mutated
  elif type_signature.is_function():
    # A function type may have no parameter; there is nothing to transform
    # in that case.
    if type_signature.parameter is not None:
      transformed_parameter, parameter_mutated = transform_type_postorder(
          type_signature.parameter, transform_fn)
    else:
      transformed_parameter, parameter_mutated = (None, False)
    transformed_result, result_mutated = transform_type_postorder(
        type_signature.result, transform_fn)
    if parameter_mutated or result_mutated:
      type_signature = computation_types.FunctionType(transformed_parameter,
                                                      transformed_result)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, (
        type_signature_mutated or parameter_mutated or result_mutated)
  elif type_signature.is_struct():
    elements = []
    elements_mutated = False
    for name, element_type in structure.iter_elements(type_signature):
      transformed_element, element_mutated = transform_type_postorder(
          element_type, transform_fn)
      elements_mutated = elements_mutated or element_mutated
      elements.append((name, transformed_element))
    if elements_mutated:
      # Keep the Python container annotation when the struct carries one.
      if type_signature.is_struct_with_python():
        type_signature = computation_types.StructWithPythonType(
            elements, type_signature.python_container)
      else:
        type_signature = computation_types.StructType(elements)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or elements_mutated
  elif (type_signature.is_abstract() or type_signature.is_placement() or
        type_signature.is_tensor()):
    # Leaf types have no children; apply the transformation directly.
    return transform_fn(type_signature)
  # Previously an unrecognized kind fell off the end and returned `None`,
  # which only failed later when a caller tried to unpack the result.
  raise TypeError(
      'Unsupported type kind for postorder transformation: {!r}'.format(
          type_signature))