in tensorflow_federated/python/core/backends/mapreduce/forms.py [0:0]
def __init__(self,
             initialize,
             prepare,
             work,
             zero,
             accumulate,
             merge,
             report,
             secure_sum_bitwidth,
             secure_sum_max_input,
             secure_modular_sum_modulus,
             update,
             server_state_label=None,
             client_data_label=None):
  """Constructs a representation of a MapReduce-like iterative process.

  Note: All the computations supplied here as arguments must be TensorFlow
  computations, i.e., instances of `tff.Computation` constructed by the
  `tff.tf_computation` decorator/wrapper.

  Besides checking that each argument is a TensorFlow computation, the
  constructor verifies that the type signatures chain together:

  * `initialize` result -> `prepare` parameter.
  * `prepare` result -> second element of the `work` parameter 2-tuple.
  * `zero` result and `accumulate` result -> first `accumulate` parameter;
    the client-update element of the `work` result -> second `accumulate`
    parameter.
  * `accumulate` result -> both `merge` parameters; `merge` result -> first
    `merge` parameter and the `report` parameter.
  * `<initialize result, <report result, secure-aggregation elements of the
    work result>>` -> `update` parameter.
  * First element of the `update` result -> `prepare` parameter.

  Args:
    initialize: The computation that produces the initial server state.
    prepare: The computation that prepares the input for the clients.
    work: The client-side work computation.
    zero: The computation that produces the initial state for accumulators.
    accumulate: The computation that adds a client update to an accumulator.
    merge: The computation to use for merging pairs of accumulators.
    report: The computation that produces the final server-side aggregate for
      the top level accumulator (the global update).
    secure_sum_bitwidth: The computation that produces the bitwidth for
      bitwidth-based secure sums.
    secure_sum_max_input: The computation that produces the maximum input for
      `max_input`-based secure sums.
    secure_modular_sum_modulus: The computation that produces the modulus for
      secure modular sums.
    update: The computation that takes the global update and the server state
      and produces the new server state, as well as server-side output.
    server_state_label: Optional string label for the server state.
    client_data_label: Optional string label for the client data.

  Raises:
    TypeError: If the Python or TFF types of the arguments are invalid or not
      compatible with each other.
    AssertionError: If the manner in which the given TensorFlow computations
      are represented by TFF does not match what this code is expecting (this
      is an internal error that requires code update).
  """
  for label, comp in (
      ('initialize', initialize),
      ('prepare', prepare),
      ('work', work),
      ('zero', zero),
      ('accumulate', accumulate),
      ('merge', merge),
      ('report', report),
      ('secure_sum_bitwidth', secure_sum_bitwidth),
      ('secure_sum_max_input', secure_sum_max_input),
      ('secure_modular_sum_modulus', secure_modular_sum_modulus),
      ('update', update),
  ):
    _check_tensorflow_computation(label, comp)

  # `initialize` -> `prepare`.
  prepare_arg_type = prepare.type_signature.parameter
  init_result_type = initialize.type_signature.result
  if not _is_assignable_from_or_both_none(prepare_arg_type, init_result_type):
    raise TypeError(
        'The `prepare` computation expects an argument of type {}, '
        'which does not match the result type {} of `initialize`.'.format(
            prepare_arg_type, init_result_type))

  # `prepare` -> `work`: `work` takes a 2-tuple whose second element is the
  # client state produced by `prepare`.
  _check_accepts_tuple('work', work, 2)
  work_2nd_arg_type = work.type_signature.parameter[1]
  prepare_result_type = prepare.type_signature.result
  if not _is_assignable_from_or_both_none(work_2nd_arg_type,
                                          prepare_result_type):
    raise TypeError(
        'The `work` computation expects an argument tuple with type {} as '
        'the second element (the initial client state from the server), '
        'which does not match the result type {} of `prepare`.'.format(
            work_2nd_arg_type, prepare_result_type))
  _check_returns_tuple('work', work, WORK_RESULT_LEN)
  work_result_type = work.type_signature.result

  # `zero` / `work` -> `accumulate`.
  py_typecheck.check_len(accumulate.type_signature.parameter, 2)
  accumulate_1st_arg_type = accumulate.type_signature.parameter[0]
  accumulate_1st_arg_type.check_assignable_from(zero.type_signature.result)
  accumulate_2nd_arg_type = accumulate.type_signature.parameter[1]
  work_client_update_type = work_result_type[WORK_UPDATE_INDEX]
  if not _is_assignable_from_or_both_none(accumulate_2nd_arg_type,
                                          work_client_update_type):
    raise TypeError(
        'The `accumulate` computation expects a second argument of type {}, '
        'which does not match the expected {} as implied by the type '
        'signature of `work`.'.format(accumulate_2nd_arg_type,
                                      work_client_update_type))
  # The result of `accumulate` must be able to feed back in as its own first
  # argument, so accumulators can be threaded through repeated calls.
  accumulate_result_type = accumulate.type_signature.result
  accumulate_1st_arg_type.check_assignable_from(accumulate_result_type)

  # `accumulate` -> `merge` -> `report`.
  py_typecheck.check_len(merge.type_signature.parameter, 2)
  merge_result_type = merge.type_signature.result
  merge.type_signature.parameter[0].check_assignable_from(
      accumulate_result_type)
  merge.type_signature.parameter[1].check_assignable_from(
      accumulate_result_type)
  merge.type_signature.parameter[0].check_assignable_from(merge_result_type)
  report.type_signature.parameter.check_assignable_from(merge_result_type)

  # `initialize` / `report` / `work` -> `update`.
  expected_update_parameter_type = computation_types.to_type([
      init_result_type,
      [
          report.type_signature.result,
          # Update takes in the post-summation values of secure aggregation.
          work_result_type[WORK_SECAGG_BITWIDTH_INDEX],
          work_result_type[WORK_SECAGG_MAX_INPUT_INDEX],
          work_result_type[WORK_SECAGG_MODULUS_INDEX],
      ],
  ])
  if not _is_assignable_from_or_both_none(update.type_signature.parameter,
                                          expected_update_parameter_type):
    raise TypeError(
        'The `update` computation expects an argument of type {}, '
        'which does not match the expected {} as implied by the type '
        'signatures of `initialize`, `report`, and `work`.'.format(
            update.type_signature.parameter, expected_update_parameter_type))

  # `update` -> `prepare`: the new server state must be usable as the state
  # input of the next round.
  _check_returns_tuple('update', update, 2)
  updated_state_type = update.type_signature.result[0]
  # Use the same helper as the `initialize` -> `prepare` check above, instead
  # of calling `is_assignable_from` directly on `prepare_arg_type`. The direct
  # call would raise AttributeError rather than TypeError if that type were
  # None.
  if not _is_assignable_from_or_both_none(prepare_arg_type,
                                          updated_state_type):
    raise TypeError(
        'The `update` computation returns a result tuple whose first element '
        '(the updated state type of the server) is type:\n'
        f'{updated_state_type}\n'
        'which is not assignable to the state parameter type of `prepare`:\n'
        f'{prepare_arg_type}')

  self._initialize = initialize
  self._prepare = prepare
  self._work = work
  self._zero = zero
  self._accumulate = accumulate
  self._merge = merge
  self._report = report
  self._secure_sum_bitwidth = secure_sum_bitwidth
  self._secure_sum_max_input = secure_sum_max_input
  self._secure_modular_sum_modulus = secure_modular_sum_modulus
  self._update = update

  # Labels are optional; when given they must be Python strings.
  if server_state_label is not None:
    py_typecheck.check_type(server_state_label, str)
  self._server_state_label = server_state_label
  if client_data_label is not None:
    py_typecheck.check_type(client_data_label, str)
  self._client_data_label = client_data_label