def _encoded_next_fn()

in tensorflow_federated/python/aggregators/encoded.py [0:0]


def _encoded_next_fn(server_state_type, value_type, encoders):
  """Creates `next_fn` for the process returned by `EncodedSumFactory`.

  The structure of the implementation is roughly as follows:
  * Extract params for encoding/decoding from state (`get_params_fn`).
  * Encode values to be aggregated, placed at clients (`encode_fn`).
  * Call `federated_aggregate` operator, with decoding of the part which does
    not commute with sum, placed in its `accumulate_fn` arg.
  * Finish decoding the summed value placed at server (`decode_after_sum_fn`).
  * Update the state placed at server (`update_state_fn`).

  The helper computations below are defined in a specific order because later
  helpers read the `type_signature` of earlier ones to declare their own
  parameter types. Reordering the definitions would break those lookups.

  Args:
    server_state_type: A `tff.Type` of the expected state placed at server.
      Its `.member` is used as the unplaced state type of the TF helpers.
    value_type: An unplaced `tff.Type` of the value to be aggregated.
    encoders: A collection of `GatherEncoder` objects. Its structure is used
      as the "shallow" structure for every `tree.map_structure_up_to` call.

  Returns:
    A `tff.Computation` for `EncodedSumFactory`, with the type signature of
    `(server_state_type, value_type@CLIENTS) ->
    MeasuredProcessOutput(server_state_type, value_type@SERVER, ()@SERVER)`
  """

  # Derives, from the server state, the three groups of params used by the
  # encoders: for encoding, for decoding before the sum, and for decoding
  # after the sum.
  @computations.tf_computation(server_state_type.member)
  def get_params_fn(state):
    # One `get_params` call per encoder, paired with its matching state.
    params = tree.map_structure_up_to(encoders, lambda e, s: e.get_params(s),
                                      encoders, state)
    # NOTE(review): `_slice` is defined elsewhere in this file; it presumably
    # selects the i-th element of each per-encoder result - confirm.
    encode_params = _slice(encoders, params, 0)
    decode_before_sum_params = _slice(encoders, params, 1)
    decode_after_sum_params = _slice(encoders, params, 2)
    return encode_params, decode_before_sum_params, decode_after_sum_params

  # Types of the three param groups, in the order returned by `get_params_fn`.
  encode_params_type = get_params_fn.type_signature.result[0]
  decode_before_sum_params_type = get_params_fn.type_signature.result[1]
  decode_after_sum_params_type = get_params_fn.type_signature.result[2]

  # TODO(b/139844355): Get rid of decode_before_sum_params.
  # We pass decode_before_sum_params to the encode method, because TFF currently
  # does not have a mechanism to make a tff.SERVER placed value available inside
  # of intrinsics.federated_aggregate - in production, this could mean an
  # intermediary aggregator node. So currently, we send the params to clients,
  # and ask them to send them back as part of the encoded structure.
  @computations.tf_computation(value_type, encode_params_type,
                               decode_before_sum_params_type)
  def encode_fn(x, encode_params, decode_before_sum_params):
    # Each encoder's `encode` is called with its part of `x` and its part of
    # `encode_params`.
    encoded_structure = tree.map_structure_up_to(
        encoders, lambda e, *args: e.encode(*args), encoders, x, encode_params)
    # Index 0 of the encode result is the encoded value, index 1 holds the
    # tensors later used to update the state.
    encoded_x = _slice(encoders, encoded_structure, 0)
    state_update_tensors = _slice(encoders, encoded_structure, 1)
    # `decode_before_sum_params` is returned unchanged (see the TODO above).
    return encoded_x, decode_before_sum_params, state_update_tensors

  state_update_tensors_type = encode_fn.type_signature.result[2]

  # This is not a @computations.tf_computation because it will be used below
  # when building the computations.tf_computations that will compose an
  # intrinsics.federated_aggregate...
  def decode_before_sum_tf_function(encoded_x, decode_before_sum_params):
    part_decoded_x = tree.map_structure_up_to(
        encoders, lambda e, *args: e.decode_before_sum(*args), encoders,
        encoded_x, decode_before_sum_params)
    # A shape-(1,) int32 tensor holding 1. It is summed together with the
    # partially decoded values in `accumulate_fn`/`merge_fn`, and the total is
    # read back as `num_summands` in `decode_after_sum_fn`.
    one = tf.constant((1,), tf.int32)
    return part_decoded_x, one

  # ...however, result type is needed to build the subsequent tf_computations.
  # The parameter type is the first two elements of the `encode_fn` result:
  # the encoded value and the decode-before-sum params.
  @computations.tf_computation(encode_fn.type_signature.result[0:2])
  def tmp_decode_before_sum_fn(encoded_x, decode_before_sum_params):
    return decode_before_sum_tf_function(encoded_x, decode_before_sum_params)

  # Despite its name, this is the type of the whole `(part_decoded_x, one)`
  # pair returned by `decode_before_sum_tf_function`.
  part_decoded_x_type = tmp_decode_before_sum_fn.type_signature.result
  del tmp_decode_before_sum_fn  # Only needed for result type.

  # Completes decoding of the summed values, given the params for decoding
  # after the sum.
  @computations.tf_computation(part_decoded_x_type,
                               decode_after_sum_params_type)
  def decode_after_sum_fn(summed_values, decode_after_sum_params):
    # `summed_values` pairs the summed partially decoded values with the
    # summed counter.
    part_decoded_aggregated_x, num_summands = summed_values
    return tree.map_structure_up_to(
        encoders,
        lambda e, x, params: e.decode_after_sum(x, params, num_summands),
        encoders, part_decoded_aggregated_x, decode_after_sum_params)

  # Produces the new state from the current state and the aggregated state
  # update tensors, via each encoder's `update_state`.
  @computations.tf_computation(server_state_type.member,
                               state_update_tensors_type)
  def update_state_fn(state, state_update_tensors):
    return tree.map_structure_up_to(encoders,
                                    lambda e, *args: e.update_state(*args),
                                    encoders, state, state_update_tensors)

  # Computations for intrinsics.federated_aggregate.
  # The accumulator is an ordered dict with two keys, so that `zero_fn`,
  # `accumulate_fn` and `merge_fn` all produce the same structure.
  def _accumulator_value(values, state_update_tensors):
    return collections.OrderedDict(
        values=values, state_update_tensors=state_update_tensors)

  # Initial accumulator: zeros matching the partially decoded values (which
  # include the summand counter) and the state update tensors.
  @computations.tf_computation
  def zero_fn():
    values = tf.nest.map_structure(
        lambda s: tf.zeros(s.shape, s.dtype),
        type_conversions.type_to_tf_tensor_specs(part_decoded_x_type))
    state_update_tensors = tf.nest.map_structure(
        lambda s: tf.zeros(s.shape, s.dtype),
        type_conversions.type_to_tf_tensor_specs(state_update_tensors_type))
    return _accumulator_value(values, state_update_tensors)

  accumulator_type = zero_fn.type_signature.result
  # For every encoder, a tuple of its `state_update_aggregation_modes`; passed
  # alongside the state update tensors when they are combined below.
  state_update_aggregation_modes = tf.nest.map_structure(
      lambda e: tuple(e.state_update_aggregation_modes), encoders)

  # Folds a single encoded client value into the accumulator.
  @computations.tf_computation(accumulator_type,
                               encode_fn.type_signature.result)
  def accumulate_fn(acc, encoded_x):
    value, params, state_update_tensors = encoded_x
    # Decoding of the part which does not commute with sum happens here,
    # before the value is added to the running sum.
    part_decoded_value = decode_before_sum_tf_function(value, params)
    new_values = tf.nest.map_structure(tf.add, acc['values'],
                                       part_decoded_value)
    # State update tensors are combined by a helper defined elsewhere in this
    # file, which also receives the per-encoder aggregation modes.
    # NOTE(review): the helper name is spelled `_accmulate_...` here; confirm
    # it matches the definition.
    new_state_update_tensors = tf.nest.map_structure(
        _accmulate_state_update_tensor, acc['state_update_tensors'],
        state_update_tensors, state_update_aggregation_modes)
    return _accumulator_value(new_values, new_state_update_tensors)

  # Combines two accumulators: values are added, state update tensors are
  # combined with the same helper used in `accumulate_fn`.
  @computations.tf_computation(accumulator_type, accumulator_type)
  def merge_fn(acc1, acc2):
    new_values = tf.nest.map_structure(tf.add, acc1['values'], acc2['values'])
    new_state_update_tensors = tf.nest.map_structure(
        _accmulate_state_update_tensor, acc1['state_update_tensors'],
        acc2['state_update_tensors'], state_update_aggregation_modes)
    return _accumulator_value(new_values, new_state_update_tensors)

  # The final accumulator is reported unchanged; the remaining decoding is
  # done by `decode_after_sum_fn` in `next_fn`.
  @computations.tf_computation(accumulator_type)
  def report_fn(acc):
    return acc

  @computations.federated_computation(server_state_type,
                                      computation_types.at_clients(value_type))
  def next_fn(state, value):
    # Compute all three param groups from the state.
    encode_params, decode_before_sum_params, decode_after_sum_params = (
        intrinsics.federated_map(get_params_fn, state))
    # Only the params needed by `encode_fn` are broadcast;
    # `decode_after_sum_params` is not broadcast and is consumed by
    # `decode_after_sum_fn` below.
    encode_params = intrinsics.federated_broadcast(encode_params)
    decode_before_sum_params = intrinsics.federated_broadcast(
        decode_before_sum_params)

    encoded_values = intrinsics.federated_map(
        encode_fn, [value, encode_params, decode_before_sum_params])

    aggregated_values = intrinsics.federated_aggregate(encoded_values,
                                                       zero_fn(), accumulate_fn,
                                                       merge_fn, report_fn)

    # `aggregated_values` has the two fields created by `_accumulator_value`.
    decoded_values = intrinsics.federated_map(
        decode_after_sum_fn,
        [aggregated_values.values, decode_after_sum_params])

    updated_state = intrinsics.federated_map(
        update_state_fn, [state, aggregated_values.state_update_tensors])

    # This process reports no measurements; an empty tuple is placed at server.
    empty_metrics = intrinsics.federated_value((), placements.SERVER)
    return measured_process.MeasuredProcessOutput(
        state=updated_state, result=decoded_values, measurements=empty_metrics)

  return next_fn