in tensorflow_federated/python/aggregators/encoded.py [0:0]
def _encoded_next_fn(server_state_type, value_type, encoders):
  """Creates `next_fn` for the process returned by `EncodedSumFactory`.

  The structure of the implementation is roughly as follows:
  * Extract params for encoding/decoding from state (`get_params_fn`).
  * Encode values to be aggregated, placed at clients (`encode_fn`).
  * Call `federated_aggregate` operator, with decoding of the part which does
    not commute with sum, placed in its `accumulate_fn` arg.
  * Finish decoding the summed value placed at server (`decode_after_sum_fn`).
  * Update the state placed at server (`update_state_fn`).

  Args:
    server_state_type: A `tff.Type` of the expected state placed at server.
    value_type: An unplaced `tff.Type` of the value to be aggregated.
    encoders: A collection of `GatherEncoder` objects. Its structure is used as
      the shallow structure of the `tree.map_structure_up_to` calls below, so
      the server state and the value are expected to share that structure at
      least down to the level of `encoders`.

  Returns:
    A `tff.Computation` for `EncodedSumFactory`, with the type signature of
    `(server_state_type, value_type@CLIENTS) ->
    MeasuredProcessOutput(server_state_type, value_type@SERVER, ()@SERVER)`
  """

  @computations.tf_computation(server_state_type.member)
  def get_params_fn(state):
    # Each encoder computes params from its own part of the state. The
    # per-encoder results are then split with `_slice` (defined elsewhere;
    # assumed to pick the i-th component for every encoder) into params for
    # encoding (0), decoding before sum (1) and decoding after sum (2).
    params = tree.map_structure_up_to(encoders, lambda e, s: e.get_params(s),
                                      encoders, state)
    encode_params = _slice(encoders, params, 0)
    decode_before_sum_params = _slice(encoders, params, 1)
    decode_after_sum_params = _slice(encoders, params, 2)
    return encode_params, decode_before_sum_params, decode_after_sum_params

  encode_params_type = get_params_fn.type_signature.result[0]
  decode_before_sum_params_type = get_params_fn.type_signature.result[1]
  decode_after_sum_params_type = get_params_fn.type_signature.result[2]

  # TODO(b/139844355): Get rid of decode_before_sum_params.
  # We pass decode_before_sum_params to the encode method, because TFF currently
  # does not have a mechanism to make a tff.SERVER placed value available inside
  # of intrinsics.federated_aggregate - in production, this could mean an
  # intermediary aggregator node. So currently, we send the params to clients,
  # and ask them to send them back as part of the encoded structure.
  @computations.tf_computation(value_type, encode_params_type,
                               decode_before_sum_params_type)
  def encode_fn(x, encode_params, decode_before_sum_params):
    # Each encoder's `encode` result is split into the encoded value (0) and
    # the tensors later aggregated to update the state (1).
    encoded_structure = tree.map_structure_up_to(
        encoders, lambda e, *args: e.encode(*args), encoders, x, encode_params)
    encoded_x = _slice(encoders, encoded_structure, 0)
    state_update_tensors = _slice(encoders, encoded_structure, 1)
    # `decode_before_sum_params` is passed through unchanged (see TODO above)
    # so that `accumulate_fn` can use it.
    return encoded_x, decode_before_sum_params, state_update_tensors

  state_update_tensors_type = encode_fn.type_signature.result[2]

  # This is not a @computations.tf_computation because it will be used below
  # when building the computations.tf_computations that will compose an
  # intrinsics.federated_aggregate...
  def decode_before_sum_tf_function(encoded_x, decode_before_sum_params):
    part_decoded_x = tree.map_structure_up_to(
        encoders, lambda e, *args: e.decode_before_sum(*args), encoders,
        encoded_x, decode_before_sum_params)
    # A shape-[1] int32 tensor holding 1, returned alongside every client's
    # part-decoded value. Because `accumulate_fn` and `merge_fn` add it together
    # with the values, the aggregate of this element is the number of summands,
    # which `decode_after_sum_fn` passes on to the encoders.
    one = tf.constant((1,), tf.int32)
    return part_decoded_x, one

  # ...however, result type is needed to build the subsequent tf_computations.
  # `encode_fn.type_signature.result[0:2]` is the type of
  # `(encoded_x, decode_before_sum_params)`.
  @computations.tf_computation(encode_fn.type_signature.result[0:2])
  def tmp_decode_before_sum_fn(encoded_x, decode_before_sum_params):
    return decode_before_sum_tf_function(encoded_x, decode_before_sum_params)

  # Despite its name, this is the type of the pair `(part_decoded_x, one)`
  # returned by `decode_before_sum_tf_function`.
  part_decoded_x_type = tmp_decode_before_sum_fn.type_signature.result
  del tmp_decode_before_sum_fn  # Only needed for result type.

  @computations.tf_computation(part_decoded_x_type,
                               decode_after_sum_params_type)
  def decode_after_sum_fn(summed_values, decode_after_sum_params):
    # `summed_values` is the summed `(part_decoded_x, one)` pair, so its second
    # element is the number of values that were summed.
    part_decoded_aggregated_x, num_summands = summed_values
    return tree.map_structure_up_to(
        encoders,
        lambda e, x, params: e.decode_after_sum(x, params, num_summands),
        encoders, part_decoded_aggregated_x, decode_after_sum_params)

  @computations.tf_computation(server_state_type.member,
                               state_update_tensors_type)
  def update_state_fn(state, state_update_tensors):
    # The new state is whatever each encoder's `update_state` returns for its
    # old state and its aggregated state update tensors.
    return tree.map_structure_up_to(encoders,
                                    lambda e, *args: e.update_state(*args),
                                    encoders, state, state_update_tensors)

  # Computations for intrinsics.federated_aggregate.
  # The accumulator holds the running sum of `(part_decoded_x, one)` under
  # `values` and the combined state update tensors under
  # `state_update_tensors`; `next_fn` reads both fields by name.
  def _accumulator_value(values, state_update_tensors):
    return collections.OrderedDict(
        values=values, state_update_tensors=state_update_tensors)

  @computations.tf_computation
  def zero_fn():
    # All-zeros tensors matching the types of the accumulated values and of
    # the state update tensors.
    # NOTE(review): state update tensors start from zeros regardless of their
    # aggregation mode; if `_accmulate_state_update_tensor` (defined elsewhere)
    # combines some modes with a min/max, a zero start could bias the result
    # — confirm.
    values = tf.nest.map_structure(
        lambda s: tf.zeros(s.shape, s.dtype),
        type_conversions.type_to_tf_tensor_specs(part_decoded_x_type))
    state_update_tensors = tf.nest.map_structure(
        lambda s: tf.zeros(s.shape, s.dtype),
        type_conversions.type_to_tf_tensor_specs(state_update_tensors_type))
    return _accumulator_value(values, state_update_tensors)

  accumulator_type = zero_fn.type_signature.result
  # One tuple of aggregation modes per encoder. It is zipped with the state
  # update tensors and handed to `_accmulate_state_update_tensor` (defined
  # elsewhere), presumably to choose how each tensor is combined.
  state_update_aggregation_modes = tf.nest.map_structure(
      lambda e: tuple(e.state_update_aggregation_modes), encoders)

  @computations.tf_computation(accumulator_type,
                               encode_fn.type_signature.result)
  def accumulate_fn(acc, encoded_x):
    # `encoded_x` is one client's `(encoded_x, decode_before_sum_params,
    # state_update_tensors)` triple as returned by `encode_fn`. The part of
    # decoding that does not commute with sum is applied here, before adding.
    value, params, state_update_tensors = encoded_x
    part_decoded_value = decode_before_sum_tf_function(value, params)
    new_values = tf.nest.map_structure(tf.add, acc['values'],
                                       part_decoded_value)
    new_state_update_tensors = tf.nest.map_structure(
        _accmulate_state_update_tensor, acc['state_update_tensors'],
        state_update_tensors, state_update_aggregation_modes)
    return _accumulator_value(new_values, new_state_update_tensors)

  @computations.tf_computation(accumulator_type, accumulator_type)
  def merge_fn(acc1, acc2):
    # Combines two partial accumulators the same way `accumulate_fn` combines
    # an accumulator with a single client contribution.
    new_values = tf.nest.map_structure(tf.add, acc1['values'], acc2['values'])
    new_state_update_tensors = tf.nest.map_structure(
        _accmulate_state_update_tensor, acc1['state_update_tensors'],
        acc2['state_update_tensors'], state_update_aggregation_modes)
    return _accumulator_value(new_values, new_state_update_tensors)

  @computations.tf_computation(accumulator_type)
  def report_fn(acc):
    # The accumulator is reported as-is; the remaining decoding happens in
    # `next_fn` via `decode_after_sum_fn`.
    return acc

  @computations.federated_computation(server_state_type,
                                      computation_types.at_clients(value_type))
  def next_fn(state, value):
    # Params are computed at server. Encode and decode-before-sum params are
    # broadcast to clients; decode-after-sum params remain at server.
    encode_params, decode_before_sum_params, decode_after_sum_params = (
        intrinsics.federated_map(get_params_fn, state))
    encode_params = intrinsics.federated_broadcast(encode_params)
    decode_before_sum_params = intrinsics.federated_broadcast(
        decode_before_sum_params)

    encoded_values = intrinsics.federated_map(
        encode_fn, [value, encode_params, decode_before_sum_params])

    aggregated_values = intrinsics.federated_aggregate(encoded_values,
                                                       zero_fn(), accumulate_fn,
                                                       merge_fn, report_fn)

    decoded_values = intrinsics.federated_map(
        decode_after_sum_fn,
        [aggregated_values.values, decode_after_sum_params])

    updated_state = intrinsics.federated_map(
        update_state_fn, [state, aggregated_values.state_update_tensors])

    # This aggregator reports no measurements.
    empty_metrics = intrinsics.federated_value((), placements.SERVER)
    return measured_process.MeasuredProcessOutput(
        state=updated_state, result=decoded_values, measurements=empty_metrics)

  return next_fn