in tensorflow_federated/python/core/impl/execution_contexts/mergeable_comp_execution_context.py [0:0]
def __init__(
    self,
    *,
    up_to_merge: computation_base.Computation,
    merge: computation_base.Computation,
    after_merge: computation_base.Computation,
):
    """Validates and stores the three stages of a mergeable computation.

    The stages are checked for mutual type compatibility:

    * `up_to_merge` must return a single `tff.SERVER`-placed value.
    * `merge` must accept a pair of values of that value's member type, and
      its result must be assignable to each element of its parameter.
    * `after_merge` must accept the merged (server-placed) value, preceded by
      the original argument when `up_to_merge` takes one.
    * `after_merge` must not contain any intrinsic whose signature accepts
      `tff.CLIENTS`-placed values and returns `tff.SERVER`-placed values.
      This check is skipped if `after_merge` has no `to_building_block`.

    Args:
      up_to_merge: Computation run before merging; its result must be
        federated and placed at `tff.SERVER`.
      merge: Computation combining two results of `up_to_merge`.
      after_merge: Computation applied to the merged value.

    Raises:
      UpToMergeTypeError: If `up_to_merge` does not return a single
        `tff.SERVER`-placed value.
      MergeTypeNotAssignableError: If `merge` takes no parameter, its
        parameter is not assignable from a pair of `up_to_merge` result
        members, or its result is not assignable to each parameter element.
      TypeError: If `after_merge` takes no parameter.
      AfterMergeStructureError: If `after_merge` contains an intrinsic moving
        values from `tff.CLIENTS` to `tff.SERVER`.

    Additionally, whatever `check_assignable_from` raises (presumably a TFF
    type-assignability error -- confirm) propagates if `after_merge`'s
    parameter is not assignable from the expected argument type.
    """
    up_to_merge_result_type = up_to_merge.type_signature.result
    if not (up_to_merge_result_type.is_federated() and
            up_to_merge_result_type.placement.is_server()):
        raise UpToMergeTypeError(
            'Expected `up_to_merge` to return a single `tff.SERVER`-placed '
            f'value; found return type {up_to_merge_result_type}.')

    # TFF's StructType assignability relation ensures that an unnamed struct can
    # be assigned to any struct with names.
    expected_merge_param_type = computation_types.StructType([
        (None, up_to_merge_result_type.member),
        (None, up_to_merge_result_type.member),
    ])
    merge_param_type = merge.type_signature.parameter
    if merge_param_type is None:
        # A no-argument `merge` cannot combine two partial results; without this
        # check the assignability call below would fail with an AttributeError.
        raise MergeTypeNotAssignableError(
            'Expected `merge` to accept a parameter assignable from '
            f'{expected_merge_param_type}; found `merge` with no parameter.')
    if not merge_param_type.is_assignable_from(expected_merge_param_type):
        raise MergeTypeNotAssignableError(
            'Type mismatch checking `merge` type signature.\n' +
            computation_types.type_mismatch_error_message(
                merge_param_type,
                expected_merge_param_type,
                computation_types.TypeRelation.ASSIGNABLE,
                second_is_expected=True))

    # `merge` is applied repeatedly, so its output must be feedable back into
    # either slot of its own parameter.
    merge_result_type = merge.type_signature.result
    if not (merge_param_type[0].is_assignable_from(merge_result_type) and
            merge_param_type[1].is_assignable_from(merge_result_type)):
        raise MergeTypeNotAssignableError(
            'Expected `merge` to have result which is assignable to '
            'each element of its parameter tuple; found parameter '
            f'of type: \n{merge_param_type}\nAnd result of type: \n'
            f'{merge_result_type}')

    merged_value_type = computation_types.at_server(merge_result_type)
    up_to_merge_param_type = up_to_merge.type_signature.parameter
    if up_to_merge_param_type is not None:
        # TODO(b/147499373): If None arguments were uniformly represented as empty
        # tuples, we could avoid this and related ugly if/else casing.
        expected_after_merge_arg_type = computation_types.StructType([
            (None, up_to_merge_param_type),
            (None, merged_value_type),
        ])
    else:
        expected_after_merge_arg_type = merged_value_type

    after_merge_param_type = after_merge.type_signature.parameter
    if after_merge_param_type is None:
        # The expected argument type above is never None, so a no-argument
        # `after_merge` can never receive the merged value.
        raise TypeError(
            'Expected `after_merge` to accept a parameter assignable from '
            f'{expected_after_merge_arg_type}; found `after_merge` with no '
            'parameter.')
    after_merge_param_type.check_assignable_from(expected_after_merge_arg_type)

    def _federated_type_predicate(
            type_signature: computation_types.Type,
            placement: placements.PlacementLiteral) -> bool:
        """Returns True if `type_signature` is federated at `placement`."""
        return (type_signature.is_federated() and
                type_signature.placement == placement)

    def _moves_clients_to_server_predicate(
            intrinsic: building_blocks.Intrinsic) -> bool:
        """Returns True if `intrinsic` takes CLIENTS values and yields SERVER ones."""
        intrinsic_param_type = intrinsic.type_signature.parameter
        if intrinsic_param_type is None:
            # An intrinsic without a parameter cannot consume clients-placed values.
            return False
        parameter_contains_clients_placement = type_analysis.contains(
            intrinsic_param_type,
            lambda x: _federated_type_predicate(x, placements.CLIENTS))
        result_contains_server_placement = type_analysis.contains(
            intrinsic.type_signature.result,
            lambda x: _federated_type_predicate(x, placements.SERVER))
        return (parameter_contains_clients_placement and
                result_contains_server_placement)

    # Collects (uri, type_signature) of every offending intrinsic the tree
    # traversal visits, for the error message below.
    aggregations = set()

    def _aggregation_predicate(
            comp: building_blocks.ComputationBuildingBlock) -> bool:
        """Returns True (and records it) if `comp` is a clients-to-server intrinsic."""
        if not comp.is_intrinsic():
            return False
        if not comp.type_signature.is_function():
            return False
        if _moves_clients_to_server_predicate(comp):
            aggregations.add((comp.uri, comp.type_signature))
            return True
        return False

    # We only know how to statically analyze computations which are backed by
    # computation.protos; to avoid opening up a visibility hole that isn't
    # technically necessary here, we prefer to simply skip the static check here
    # for computations which cannot convert themselves to building blocks.
    if hasattr(after_merge, 'to_building_block') and tree_analysis.contains(
            after_merge.to_building_block(), _aggregation_predicate):
        # Sort so the error message does not depend on set iteration order.
        formatted_aggregations = ', '.join(
            sorted('{}: {}'.format(uri, type_signature)
                   for uri, type_signature in aggregations))
        raise AfterMergeStructureError(
            'Expected `after_merge` to contain no intrinsics '
            'with signatures accepting values at clients and '
            'returning values at server. Found the following '
            f'aggregations: {formatted_aggregations}')

    self.up_to_merge = up_to_merge
    self.merge = merge
    self.after_merge = after_merge