in tensorflow_federated/python/core/impl/compiler/tree_transformations.py [0:0]
def strip_placement(comp):
  """Strips `comp`'s placement, returning a non-federated computation.

  For this function to complete successfully `comp` must:

  1) contain at most one federated placement.
  2) not contain intrinsics besides `federated_apply`, `federated_map`,
     `federated_map_all_equal`, `federated_zip`, `federated_value`, and
     `federated_eval` (the latter three at either server or clients).
  3) not contain `building_blocks.Data` of federated type.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` satisfying the
      assumptions above.

  Returns:
    The result of `transformation_utils.transform_postorder`: a pair whose
    first element is a modified version of `comp` containing no intrinsics nor
    any federated types or values, and whose second element indicates whether
    anything was modified.

  Raises:
    TypeError: If `comp` is not a building block.
    ValueError: If conditions (1), (2), or (3) above are unsatisfied.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  # The single placement seen so far; stays `None` until the first federated
  # type is encountered.
  placement = None
  name_generator = building_block_factory.unique_name_generator(comp)

  def _ensure_single_placement(new_placement):
    """Records `new_placement`, raising if it differs from an earlier one."""
    nonlocal placement
    if placement is None:
      placement = new_placement
    elif placement != new_placement:
      raise ValueError(
          'Attempted to `strip_placement` from computation containing '
          'multiple different placements.\n'
          f'Found placements `{placement}` and `{new_placement}` in '
          f'comp:\n{comp.compact_representation()}')

  def _remove_placement_from_type(type_spec):
    """Type-transform callback mapping `T@P` to `T` and recording `P`."""
    if type_spec.is_federated():
      _ensure_single_placement(type_spec.placement)
      return type_spec.member, True
    else:
      return type_spec, False

  def _remove_reference_placement(comp):
    """Rebuilds reference `comp` with all placements removed from its type."""
    new_type, _ = type_transformations.transform_type_postorder(
        comp.type_signature, _remove_placement_from_type)
    return building_blocks.Reference(comp.name, new_type)

  def _identity_function(arg_type):
    """Creates `lambda x: x` with argument type `arg_type`."""
    arg_name = next(name_generator)
    val = building_blocks.Reference(arg_name, arg_type)
    return building_blocks.Lambda(arg_name, arg_type, val)

  def _call_first_with_second_function(fn_type, arg_type):
    """Creates `lambda x: x[0](x[1])`.

    The parameter `x` has the struct type `<fn_type, arg_type>`.
    """
    arg_name = next(name_generator)
    tuple_ref = building_blocks.Reference(arg_name, [fn_type, arg_type])
    fn = building_blocks.Selection(tuple_ref, index=0)
    arg = building_blocks.Selection(tuple_ref, index=1)
    called_fn = building_blocks.Call(fn, arg)
    return building_blocks.Lambda(arg_name, tuple_ref.type_signature, called_fn)

  def _call_function(arg_type):
    """Creates `lambda x: x()` with argument type `arg_type`."""
    arg_name = next(name_generator)
    arg_ref = building_blocks.Reference(arg_name, arg_type)
    called_arg = building_blocks.Call(arg_ref, None)
    return building_blocks.Lambda(arg_name, arg_type, called_arg)

  def _replace_intrinsics_with_functions(comp):
    """Replaces an allowed intrinsic with an equivalent non-federated lambda."""
    tys = comp.type_signature
    # These functions have no runtime behavior and only exist to adjust
    # placement. They are replaced here with `lambda x: x`.
    identities = [
        intrinsic_defs.FEDERATED_ZIP_AT_SERVER.uri,
        intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS.uri,
        intrinsic_defs.FEDERATED_VALUE_AT_SERVER.uri,
        intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS.uri,
    ]
    # These functions all `map` a value and are replaced with
    # `lambda args: args[0](args[1])`.
    maps = [
        intrinsic_defs.FEDERATED_MAP.uri,
        intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri,
        intrinsic_defs.FEDERATED_APPLY.uri,
    ]
    # `federated_eval`'s argument must simply be `call`ed and is replaced
    # with `lambda x: x()`.
    evals = [
        intrinsic_defs.FEDERATED_EVAL_AT_SERVER.uri,
        intrinsic_defs.FEDERATED_EVAL_AT_CLIENTS.uri,
    ]
    if comp.uri in identities:
      replacement = _identity_function(tys.result.member)
    elif comp.uri in maps:
      replacement = _call_first_with_second_function(
          tys.parameter[0], tys.parameter[1].member)
    elif comp.uri in evals:
      replacement = _call_function(tys.parameter)
    else:
      raise ValueError('Disallowed intrinsic: {}'.format(comp))
    # An intrinsic's placement (e.g. `federated_value_at_clients` applied to a
    # non-federated value) is not visible through any reference or lambda
    # parameter type, so record every placement in its signature here to
    # enforce condition (1).
    type_transformations.transform_type_postorder(tys,
                                                  _remove_placement_from_type)
    return replacement

  def _remove_lambda_placement(comp):
    """Removes placement from Lambda's parameter type."""
    if comp.parameter_name is None:
      new_parameter_type = None
    else:
      new_parameter_type, _ = type_transformations.transform_type_postorder(
          comp.parameter_type, _remove_placement_from_type)
    return building_blocks.Lambda(comp.parameter_name, new_parameter_type,
                                  comp.result)

  def _simplify_calls(comp):
    """Unwraps calls to the lambdas introduced by removing intrinsics.

    Rewrites `(lambda x: x)(arg)` to `arg` and `(lambda x: x[0](x[1]))(<f, a>)`
    to `f(a)`. Any other call of a lambda is returned unchanged.
    """
    param_name = comp.function.parameter_name
    body = comp.function.result

    def _is_param_selection(selection, index):
      # `selection` may be `None` (the argument of a no-argument call).
      return (selection is not None and selection.is_selection() and
              selection.index == index and selection.source.is_reference() and
              selection.source.name == param_name)

    zip_or_value_removed = body.is_reference() and body.name == param_name
    if zip_or_value_removed:
      return comp.argument
    # Both the called function AND its argument must be selections from the
    # lambda's own parameter; otherwise the rewrite would change semantics.
    map_removed = (
        body.is_call() and _is_param_selection(body.function, 0) and
        _is_param_selection(body.argument, 1) and comp.argument is not None and
        comp.argument.is_struct())
    if map_removed:
      return building_blocks.Call(comp.argument[0], comp.argument[1])
    return comp

  def _transform(comp):
    """Dispatches to helpers above; returns a `(comp, modified)` pair."""
    if comp.is_reference():
      return _remove_reference_placement(comp), True
    elif comp.is_intrinsic():
      return _replace_intrinsics_with_functions(comp), True
    elif comp.is_lambda():
      return _remove_lambda_placement(comp), True
    elif comp.is_call() and comp.function.is_lambda():
      return _simplify_calls(comp), True
    elif comp.is_data() and comp.type_signature.is_federated():
      raise ValueError(f'Cannot strip placement from federated data: {comp}')
    return comp, False

  return transformation_utils.transform_postorder(comp, _transform)