in tensorflow_federated/python/core/impl/compiler/tree_analysis.py [0:0]
def trees_equal(comp_1, comp_2):
  """Returns `True` if the computations are structurally equivalent.

  Equivalent computations with different operation orderings are
  not considered to be equal, but computations which are equivalent up to
  reference renaming are. If either argument is `None`, returns true if and
  only if both arguments are `None`. Note that this is the desired semantics
  here, since `None` can appear as the argument to a `building_blocks.Call`
  and therefore is considered a valid tree.

  Note that the equivalence relation here is also known as alpha equivalence
  in lambda calculus. Accordingly, a name bound by a lambda parameter or a
  block local is only treated as renamed inside the scope of that binding:
  references outside that scope (or references that are bound in neither tree)
  must carry identical names, and a reference that is bound in one tree is
  never equal to a reference that is free, or bound by a different binder, in
  the other tree.

  Args:
    comp_1: A `building_blocks.ComputationBuildingBlock` to test.
    comp_2: A `building_blocks.ComputationBuildingBlock` to test.

  Returns:
    `True` exactly when the computations are structurally equal, up to
    renaming.

  Raises:
    TypeError: If `comp_1` or `comp_2` is not an instance of
      `building_blocks.ComputationBuildingBlock`.
    NotImplementedError: If `comp_1` and `comp_2` are an unexpected subclass of
      `building_blocks.ComputationBuildingBlock`.
  """
  py_typecheck.check_type(
      comp_1, (building_blocks.ComputationBuildingBlock, type(None)))
  py_typecheck.check_type(
      comp_2, (building_blocks.ComputationBuildingBlock, type(None)))

  def _references_equal(name_1, name_2, reference_equivalences):
    """Compares two reference names under the bindings currently in scope."""
    # Walk the bindings from the innermost scope outwards. The first binding
    # that captures either name decides the answer: both references must be
    # captured by that same pair of binders, otherwise one of them is shadowed
    # or free where the other is bound.
    for bound_1, bound_2 in reversed(reference_equivalences):
      captures_1 = name_1 == bound_1
      captures_2 = name_2 == bound_2
      if captures_1 or captures_2:
        return captures_1 and captures_2
    # Neither name is bound in the compared trees; free references are equal
    # only if they have the same name.
    return name_1 == name_2

  def _trees_equal(comp_1, comp_2, reference_equivalences):
    """Internal helper for `trees_equal`.

    `reference_equivalences` is a stack of `(name_in_comp_1, name_in_comp_2)`
    pairs for the binders enclosing the current position; every branch that
    pushes onto it restores it before returning.
    """
    if comp_1 is None or comp_2 is None:
      return comp_1 is None and comp_2 is None
    # The unidiomatic-typecheck is intentional, for the purposes of equality
    # this function requires that the types are identical and that a subclass
    # will not be equal to its baseclass.
    if type(comp_1) != type(comp_2):  # pylint: disable=unidiomatic-typecheck
      return False
    if comp_1.type_signature != comp_2.type_signature:
      return False
    if comp_1.is_block():
      if len(comp_1.locals) != len(comp_2.locals):
        return False
      scope_start = len(reference_equivalences)
      try:
        for (name_1, value_1), (name_2, value_2) in zip(comp_1.locals,
                                                        comp_2.locals):
          # A local's value is compared before its own name is bound, so it
          # only sees the locals declared before it.
          if not _trees_equal(value_1, value_2, reference_equivalences):
            return False
          reference_equivalences.append((name_1, name_2))
        return _trees_equal(comp_1.result, comp_2.result,
                            reference_equivalences)
      finally:
        # The locals go out of scope with the block; drop them so they cannot
        # affect sibling subtrees compared afterwards.
        del reference_equivalences[scope_start:]
    elif comp_1.is_call():
      return (_trees_equal(comp_1.function, comp_2.function,
                           reference_equivalences) and
              _trees_equal(comp_1.argument, comp_2.argument,
                           reference_equivalences))
    elif comp_1.is_compiled_computation():
      return _compiled_comp_equal(comp_1, comp_2)
    elif comp_1.is_data():
      return comp_1.uri == comp_2.uri
    elif comp_1.is_intrinsic():
      return comp_1.uri == comp_2.uri
    elif comp_1.is_lambda():
      if comp_1.parameter_type != comp_2.parameter_type:
        return False
      scope_start = len(reference_equivalences)
      reference_equivalences.append(
          (comp_1.parameter_name, comp_2.parameter_name))
      try:
        return _trees_equal(comp_1.result, comp_2.result,
                            reference_equivalences)
      finally:
        # The parameter is only bound inside the lambda body.
        del reference_equivalences[scope_start:]
    elif comp_1.is_placement():
      return comp_1.uri == comp_2.uri
    elif comp_1.is_reference():
      return _references_equal(comp_1.name, comp_2.name,
                               reference_equivalences)
    elif comp_1.is_selection():
      return (comp_1.name == comp_2.name and
              comp_1.index == comp_2.index and _trees_equal(
                  comp_1.source, comp_2.source, reference_equivalences))
    elif comp_1.is_struct():
      # The element names are checked as part of the `type_signature`.
      if len(comp_1) != len(comp_2):
        return False
      for element_1, element_2 in zip(comp_1, comp_2):
        if not _trees_equal(element_1, element_2, reference_equivalences):
          return False
      return True
    raise NotImplementedError('Unexpected type found: {}.'.format(type(comp_1)))

  return _trees_equal(comp_1, comp_2, [])